Mirror of https://github.com/langgenius/dify.git (synced 2026-01-19 19:55:06 +08:00)

Compare commits: 0.15.1...fix/versio (440 commits)
[Commit table: 440 bare SHAs, 3d9c8d76b9 through 364df36ac4; the Author and Date columns were empty in the source view.]
@@ -1,11 +1,12 @@
 #!/bin/bash

-cd web && npm install
+npm add -g pnpm@9.12.2
+cd web && pnpm install

 pipx install poetry

 echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
 echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
-echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
+echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
 echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
.github/actions/setup-poetry/action.yml (vendored, 2 lines changed)

@@ -8,7 +8,7 @@ inputs:
   poetry-version:
     description: Poetry version to set up
     required: true
-    default: '1.8.4'
+    default: '2.0.1'
   poetry-lockfile:
     description: Path to the Poetry lockfile to restore cache from
     required: true
.github/workflows/api-tests.yml (vendored, 19 lines changed)

@@ -4,6 +4,7 @@ on:
   pull_request:
     branches:
       - main
+      - plugins/beta
     paths:
       - api/**
       - docker/**
@@ -42,25 +43,17 @@ jobs:
         run: poetry install -C api --with dev

       - name: Check dependencies in pyproject.toml
-        run: poetry run -C api bash dev/pytest/pytest_artifacts.sh
+        run: poetry run -P api bash dev/pytest/pytest_artifacts.sh

       - name: Run Unit tests
-        run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh
-
-      - name: Run ModelRuntime
-        run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
+        run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh

       - name: Run dify config tests
-        run: poetry run -C api python dev/pytest/pytest_config_tests.py
-
-      - name: Run Tool
-        run: poetry run -C api bash dev/pytest/pytest_tools.sh
+        run: poetry run -P api python dev/pytest/pytest_config_tests.py

       - name: Run mypy
         run: |
-          pushd api
-          poetry run python -m mypy --install-types --non-interactive .
-          popd
+          poetry run -C api python -m mypy --install-types --non-interactive .

       - name: Set up dotenvs
         run: |
@@ -80,4 +73,4 @@ jobs:
           ssrf_proxy

       - name: Run Workflow
-        run: poetry run -C api bash dev/pytest/pytest_workflow.sh
+        run: poetry run -P api bash dev/pytest/pytest_workflow.sh
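The `-C` → `-P` swaps here (and in the style, vdb, and dev-docs hunks below) appear to track the Poetry 1.8.4 → 2.0.1 bump visible in the action and Dockerfile hunks: Poetry 2.0 introduced `-P/--project` for pointing at a project root such as `api/`, while `-C/--directory` now only changes the working directory, so script invocations that relied on the old `-C` behavior are migrated to `-P`.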
.github/workflows/build-push.yml (vendored, 1 line changed)

@@ -5,6 +5,7 @@ on:
     branches:
       - "main"
       - "deploy/dev"
+      - "plugins/beta"
   release:
     types: [published]
.github/workflows/db-migration-test.yml (vendored, 1 line changed)

@@ -4,6 +4,7 @@ on:
   pull_request:
     branches:
       - main
+      - plugins/beta
     paths:
      - api/migrations/**
      - .github/workflows/db-migration-test.yml
.github/workflows/docker-build.yml (vendored, new file, 47 lines)

@@ -0,0 +1,47 @@
+name: Build docker image
+
+on:
+  pull_request:
+    branches:
+      - "main"
+    paths:
+      - api/Dockerfile
+      - web/Dockerfile
+
+concurrency:
+  group: docker-build-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  build-docker:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        include:
+          - service_name: "api-amd64"
+            platform: linux/amd64
+            context: "api"
+          - service_name: "api-arm64"
+            platform: linux/arm64
+            context: "api"
+          - service_name: "web-amd64"
+            platform: linux/amd64
+            context: "web"
+          - service_name: "web-arm64"
+            platform: linux/arm64
+            context: "web"
+    steps:
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Build Docker Image
+        uses: docker/build-push-action@v6
+        with:
+          push: false
+          context: "{{defaultContext}}:${{ matrix.context }}"
+          platforms: ${{ matrix.platform }}
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
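This new workflow smoke-builds the api and web images on pull requests that touch either Dockerfile: the four-entry matrix crosses the two build contexts with linux/amd64 and linux/arm64, QEMU supplies arm64 emulation on the x86 runner, Buildx drives the build, and `cache-from`/`cache-to: type=gha` keep layers in the Actions cache. With `push: false`, nothing is published; it is a build check only.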
.github/workflows/style.yml (vendored, 46 lines changed)

@@ -4,6 +4,7 @@ on:
   pull_request:
     branches:
       - main
+      - plugins/beta

 concurrency:
   group: style-${{ github.head_ref || github.run_id }}
@@ -38,12 +39,12 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: |
           poetry run -C api ruff --version
-          poetry run -C api ruff check ./api
-          poetry run -C api ruff format --check ./api
+          poetry run -C api ruff check ./
+          poetry run -C api ruff format --check ./

       - name: Dotenv check
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
+        run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example

       - name: Lint hints
         if: failure()
@@ -66,22 +67,55 @@ jobs:
         with:
           files: web/**

+      - name: Install pnpm
+        uses: pnpm/action-setup@v4
+        with:
+          version: 10
+          run_install: false
+
       - name: Setup NodeJS
         uses: actions/setup-node@v4
         if: steps.changed-files.outputs.any_changed == 'true'
         with:
           node-version: 20
-          cache: yarn
+          cache: pnpm
           cache-dependency-path: ./web/package.json

       - name: Web dependencies
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: yarn install --frozen-lockfile
+        run: pnpm install --frozen-lockfile

       - name: Web style check
         if: steps.changed-files.outputs.any_changed == 'true'
         run: yarn run lint

+  docker-compose-template:
+    name: Docker Compose Template
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Check changed files
+        id: changed-files
+        uses: tj-actions/changed-files@v45
+        with:
+          files: |
+            docker/generate_docker_compose
+            docker/.env.example
+            docker/docker-compose-template.yaml
+            docker/docker-compose.yaml
+
+      - name: Generate Docker Compose
+        if: steps.changed-files.outputs.any_changed == 'true'
+        run: |
+          cd docker
+          ./generate_docker_compose
+
+      - name: Check for changes
+        if: steps.changed-files.outputs.any_changed == 'true'
+        run: git diff --exit-code
+
   superlinter:
     name: SuperLinter
@@ -107,7 +141,7 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         env:
           BASH_SEVERITY: warning
-          DEFAULT_BRANCH: main
+          DEFAULT_BRANCH: plugins/beta
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
           IGNORE_GENERATED_FILES: true
           IGNORE_GITIGNORED_FILES: true
.github/workflows/tool-test-sdks.yaml (vendored, 6 lines changed)

@@ -32,10 +32,10 @@ jobs:
         with:
           node-version: ${{ matrix.node-version }}
           cache: ''
-          cache-dependency-path: 'yarn.lock'
+          cache-dependency-path: 'pnpm-lock.yaml'

       - name: Install Dependencies
-        run: yarn install
+        run: pnpm install

       - name: Test
-        run: yarn test
+        run: pnpm test

@@ -38,11 +38,11 @@ jobs:
       - name: Install dependencies
         if: env.FILES_CHANGED == 'true'
-        run: yarn install --frozen-lockfile
+        run: pnpm install --frozen-lockfile

       - name: Run npm script
         if: env.FILES_CHANGED == 'true'
-        run: npm run auto-gen-i18n
+        run: pnpm run auto-gen-i18n

       - name: Create Pull Request
         if: env.FILES_CHANGED == 'true'
.github/workflows/vdb-tests.yml (vendored, 2 lines changed)

@@ -70,4 +70,4 @@ jobs:
           tidb

       - name: Test Vector Stores
-        run: poetry run -C api bash dev/pytest/pytest_vdb.sh
+        run: poetry run -P api bash dev/pytest/pytest_vdb.sh
.github/workflows/web-tests.yml (vendored, 6 lines changed)

@@ -34,13 +34,13 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         with:
           node-version: 20
-          cache: yarn
+          cache: pnpm
           cache-dependency-path: ./web/package.json

       - name: Install dependencies
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: yarn install --frozen-lockfile
+        run: pnpm install --frozen-lockfile

       - name: Run tests
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: yarn test
+        run: pnpm test
.gitignore (vendored, 7 lines changed)

@@ -175,6 +175,7 @@ docker/volumes/pgvector/data/*
 docker/volumes/pgvecto_rs/data/*
 docker/volumes/couchbase/*
 docker/volumes/oceanbase/*
+docker/volumes/plugin_daemon/*
 !docker/volumes/oceanbase/init.d

 docker/nginx/conf.d/default.conf
@@ -193,3 +194,9 @@ api/.vscode

 .idea/
 .vscode
+
+# pnpm
+/.pnpm-store
+
+# plugin migrate
+plugins.jsonl
The next set of hunks adds a LinkedIn badge to each README variant. The first, shown in full:

@@ -25,6 +25,9 @@
   <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
     <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
       alt="follow on X(Twitter)"></a>
+  <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+    <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+      alt="follow on LinkedIn"></a>
   <a href="https://hub.docker.com/u/langgenius" target="_blank">
     <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
   <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">

Eleven near-identical hunks (at lines 21, 22, or 25 of the other README variants; file names are not shown in this view) insert the same three-line badge, differing only in localized alt text: "seguir en LinkedIn" (Spanish), "suivre sur LinkedIn" (French), "LinkedInでフォロー" (Japanese), "LinkedIn'da takip et" (Turkish), "theo dõi trên LinkedIn" (Vietnamese); the rest use the English "follow on LinkedIn". The surrounding Docker Pulls badge alt text is likewise localized ("Descargas de Docker", "Tirages Docker", "Docker Çekmeleri").
Three further hunks in the Turkish README delete stray machine-translation artifacts (leftover assistant chatter such as "Özür dilerim, haklısınız. Daha anlamlı ve akıcı bir çeviri yapmaya çalışayım. İşte güncellenmiş çeviri:" — "I apologize, you're right. Let me try a more meaningful and fluent translation. Here is the updated translation:"):

@@ -62,8 +65,6 @@ Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edin
 

-
-Özür dilerim, haklısınız. Daha anlamlı ve akıcı bir çeviri yapmaya çalışayım. İşte güncellenmiş çeviri:

 **3. Prompt IDE**:
   Komut istemlerini oluşturmak, model performansını karşılaştırmak ve sohbet tabanlı uygulamalara metin-konuşma gibi ek özellikler eklemek için kullanıcı dostu bir arayüz.
@@ -150,8 +151,6 @@ Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edin
 ## Dify'ı Kullanma

 - **Cloud </br>**
-İşte verdiğiniz metnin Türkçe çevirisi, kod bloğu içinde:
-
   Herkesin sıfır kurulumla denemesi için bir [Dify Cloud](https://dify.ai) hizmeti sunuyoruz. Bu hizmet, kendi kendine dağıtılan versiyonun tüm yeteneklerini sağlar ve sandbox planında 200 ücretsiz GPT-4 çağrısı içerir.

 - **Dify Topluluk Sürümünü Kendi Sunucunuzda Barındırma</br>**
@@ -177,8 +176,6 @@ GitHub'da Dify'a yıldız verin ve yeni sürümlerden anında haberdar olun.
 >- RAM >= 4GB

 </br>
-İşte verdiğiniz metnin Türkçe çevirisi, kod bloğu içinde:
-
 Dify sunucusunu başlatmanın en kolay yolu, [docker-compose.yml](docker/docker-compose.yaml) dosyamızı çalıştırmaktır. Kurulum komutunu çalıştırmadan önce, makinenizde [Docker](https://docs.docker.com/get-docker/) ve [Docker Compose](https://docs.docker.com/compose/install/)'un kurulu olduğundan emin olun:

 ```bash
@@ -1,7 +1,10 @@
 .env
 *.env.*

+storage/generate_files/*
 storage/privkeys/*
+storage/tools/*
+storage/upload_files/*

 # Logs
 logs
@@ -9,6 +12,8 @@ logs

 # jetbrains
 .idea
+.mypy_cache
+.ruff_cache

 # venv
 .venv
@@ -409,7 +409,6 @@ MAX_VARIABLE_SIZE=204800
 APP_MAX_EXECUTION_TIME=1200
 APP_MAX_ACTIVE_REQUESTS=0

-
 # Celery beat configuration
 CELERY_BEAT_SCHEDULER_TIME=1
@@ -422,6 +421,22 @@ POSITION_PROVIDER_PINS=
 POSITION_PROVIDER_INCLUDES=
 POSITION_PROVIDER_EXCLUDES=

+# Plugin configuration
+PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
+PLUGIN_DAEMON_URL=http://127.0.0.1:5002
+PLUGIN_REMOTE_INSTALL_PORT=5003
+PLUGIN_REMOTE_INSTALL_HOST=localhost
+PLUGIN_MAX_PACKAGE_SIZE=15728640
+INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
+INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
+
+# Marketplace configuration
+MARKETPLACE_ENABLED=true
+MARKETPLACE_API_URL=https://marketplace.dify.ai
+
+# Endpoint configuration
+ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
+
+# Reset password token expiry minutes
+RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
@@ -53,10 +53,12 @@ ignore = [
     "FURB152", # math-constant
     "UP007", # non-pep604-annotation
     "UP032", # f-string
+    "UP045", # non-pep604-annotation-optional
     "B005", # strip-with-multi-characters
     "B006", # mutable-argument-default
     "B007", # unused-loop-control-variable
     "B026", # star-arg-unpacking-after-keyword-arg
+    "B903", # class-as-data-structure
     "B904", # raise-without-from-inside-except
     "B905", # zip-without-explicit-strict
     "N806", # non-lowercase-variable-in-function

@@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
 WORKDIR /app/api

 # Install Poetry
-ENV POETRY_VERSION=1.8.4
+ENV POETRY_VERSION=2.0.1

 # if you located in China, you can use aliyun mirror to speed up
 # RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
@@ -48,16 +48,18 @@ ENV TZ=UTC

 WORKDIR /app/api

-RUN apt-get update \
-    && apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
-    # if you located in China, you can use aliyun mirror to speed up
-    # && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \
-    && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
-    && apt-get update \
-    # For Security
-    && apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.19+dfsg-1 perl=5.40.0-8 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
-    # install a chinese font to support the use of tools like matplotlib
-    && apt-get install -y fonts-noto-cjk \
+RUN \
+    apt-get update \
+    # Install dependencies
+    && apt-get install -y --no-install-recommends \
+        # basic environment
+        curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
+        # For Security
+        # expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
+        # install a chinese font to support the use of tools like matplotlib
+        fonts-noto-cjk \
+        # install libmagic to support the use of python-magic guess MIMETYPE
+        libmagic1 \
     && apt-get autoremove -y \
     && rm -rf /var/lib/apt/lists/*
@@ -69,6 +71,10 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
 # Download nltk data
 RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"

+ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
+
+RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
+
 # Copy source code
 COPY . /app/api/
@@ -76,7 +82,6 @@ COPY . /app/api/
 COPY docker/entrypoint.sh /entrypoint.sh
 RUN chmod +x /entrypoint.sh

-
 ARG COMMIT_SHA
 ENV COMMIT_SHA=${COMMIT_SHA}
@@ -79,5 +79,5 @@
 2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`

    ```bash
-   poetry run -C api bash dev/pytest/pytest_all_tests.sh
+   poetry run -P api bash dev/pytest/pytest_all_tests.sh
    ```
@@ -25,6 +25,8 @@ from models.dataset import Document as DatasetDocument
 from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
 from models.provider import Provider, ProviderModel
 from services.account_service import RegisterService, TenantService
+from services.plugin.data_migration import PluginDataMigration
+from services.plugin.plugin_migration import PluginMigration


 @click.command("reset-password", help="Reset the account password.")
@@ -524,7 +526,7 @@ def add_qdrant_doc_id_index(field: str):
             )
         )

-    except Exception as e:
+    except Exception:
         click.echo(click.style("Failed to create Qdrant client.", fg="red"))

     click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green"))
@@ -593,7 +595,7 @@ def upgrade_db():

         click.echo(click.style("Database migration successful!", fg="green"))

-    except Exception as e:
+    except Exception:
         logging.exception("Failed to execute database migration")
     finally:
         lock.release()
@@ -639,7 +641,7 @@ where sites.id is null limit 1000"""
                 account = accounts[0]
                 print("Fixing missing site for app {}".format(app.id))
                 app_was_created.send(app, account=account)
-            except Exception as e:
+            except Exception:
                 failed_app_ids.append(app_id)
                 click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
                 logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
@@ -649,3 +651,68 @@ where sites.id is null limit 1000"""
             break

     click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
+
+
+@click.command("migrate-data-for-plugin", help="Migrate data for plugin.")
+def migrate_data_for_plugin():
+    """
+    Migrate data for plugin.
+    """
+    click.echo(click.style("Starting migrate data for plugin.", fg="white"))
+
+    PluginDataMigration.migrate()
+
+    click.echo(click.style("Migrate data for plugin completed.", fg="green"))
+
+
+@click.command("extract-plugins", help="Extract plugins.")
+@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
+@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
+def extract_plugins(output_file: str, workers: int):
+    """
+    Extract plugins.
+    """
+    click.echo(click.style("Starting extract plugins.", fg="white"))
+
+    PluginMigration.extract_plugins(output_file, workers)
+
+    click.echo(click.style("Extract plugins completed.", fg="green"))
+
+
+@click.command("extract-unique-identifiers", help="Extract unique identifiers.")
+@click.option(
+    "--output_file",
+    prompt=True,
+    help="The file to store the extracted unique identifiers.",
+    default="unique_identifiers.json",
+)
+@click.option(
+    "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
+)
+def extract_unique_plugins(output_file: str, input_file: str):
+    """
+    Extract unique plugins.
+    """
+    click.echo(click.style("Starting extract unique plugins.", fg="white"))
+
+    PluginMigration.extract_unique_plugins_to_file(input_file, output_file)
+
+    click.echo(click.style("Extract unique plugins completed.", fg="green"))
+
+
+@click.command("install-plugins", help="Install plugins.")
+@click.option(
+    "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
+)
+@click.option(
+    "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
+)
+def install_plugins(input_file: str, output_file: str):
+    """
+    Install plugins.
+    """
+    click.echo(click.style("Starting install plugins.", fg="white"))
+
+    PluginMigration.install_plugins(input_file, output_file)
+
+    click.echo(click.style("Install plugins completed.", fg="green"))
@@ -134,6 +134,60 @@ class CodeExecutionSandboxConfig(BaseSettings):
     )


+class PluginConfig(BaseSettings):
+    """
+    Plugin configs
+    """
+
+    PLUGIN_DAEMON_URL: HttpUrl = Field(
+        description="Plugin API URL",
+        default="http://localhost:5002",
+    )
+
+    PLUGIN_DAEMON_KEY: str = Field(
+        description="Plugin API key",
+        default="plugin-api-key",
+    )
+
+    INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
+
+    PLUGIN_REMOTE_INSTALL_HOST: str = Field(
+        description="Plugin Remote Install Host",
+        default="localhost",
+    )
+
+    PLUGIN_REMOTE_INSTALL_PORT: PositiveInt = Field(
+        description="Plugin Remote Install Port",
+        default=5003,
+    )
+
+    PLUGIN_MAX_PACKAGE_SIZE: PositiveInt = Field(
+        description="Maximum allowed size for plugin packages in bytes",
+        default=15728640,
+    )
+
+    PLUGIN_MAX_BUNDLE_SIZE: PositiveInt = Field(
+        description="Maximum allowed size for plugin bundles in bytes",
+        default=15728640 * 12,
+    )
+
+
+class MarketplaceConfig(BaseSettings):
+    """
+    Configuration for marketplace
+    """
+
+    MARKETPLACE_ENABLED: bool = Field(
+        description="Enable or disable marketplace",
+        default=True,
+    )
+
+    MARKETPLACE_API_URL: HttpUrl = Field(
+        description="Marketplace API URL",
+        default="https://marketplace.dify.ai",
+    )
+
+
 class EndpointConfig(BaseSettings):
     """
     Configuration for various application endpoints and URLs
@@ -146,7 +200,7 @@ class EndpointConfig(BaseSettings):
     )

     CONSOLE_WEB_URL: str = Field(
-        description="Base URL for the console web interface," "used for frontend references and CORS configuration",
+        description="Base URL for the console web interface,used for frontend references and CORS configuration",
         default="",
     )
@@ -160,6 +214,10 @@ class EndpointConfig(BaseSettings):
         default="",
     )

+    ENDPOINT_URL_TEMPLATE: str = Field(
+        description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
+    )
+

 class FileAccessConfig(BaseSettings):
     """
@@ -498,6 +556,11 @@ class AuthConfig(BaseSettings):
         default=86400,
     )

+    FORGOT_PASSWORD_LOCKOUT_DURATION: PositiveInt = Field(
+        description="Time (in seconds) a user must wait before retrying password reset after exceeding the rate limit.",
+        default=86400,
+    )
+

 class ModerationConfig(BaseSettings):
     """
@@ -788,6 +851,8 @@ class FeatureConfig(
     AuthConfig,  # Changed from OAuthConfig to AuthConfig
     BillingConfig,
     CodeExecutionSandboxConfig,
+    PluginConfig,
+    MarketplaceConfig,
     DataSetConfig,
     EndpointConfig,
     FileAccessConfig,
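These settings classes pick their values up from environment variables of the same name (which is why the `.env.example` hunk above mirrors the field names). A minimal standalone sketch of that mechanism — `DemoPluginConfig` is an illustrative stand-in, not Dify's class:

```python
# Sketch: pydantic-settings populates fields from same-named env vars,
# falling back to the declared defaults (assumes pydantic-settings installed).
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class DemoPluginConfig(BaseSettings):
    PLUGIN_DAEMON_URL: str = Field(default="http://localhost:5002")
    PLUGIN_MAX_PACKAGE_SIZE: int = Field(default=15728640)


os.environ["PLUGIN_DAEMON_URL"] = "http://127.0.0.1:5002"
config = DemoPluginConfig()
print(config.PLUGIN_DAEMON_URL)        # -> http://127.0.0.1:5002 (from env)
print(config.PLUGIN_MAX_PACKAGE_SIZE)  # -> 15728640 (default)
```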
@@ -1,9 +1,40 @@
 from typing import Optional

-from pydantic import Field, NonNegativeInt
+from pydantic import Field, NonNegativeInt, computed_field
 from pydantic_settings import BaseSettings


+class HostedCreditConfig(BaseSettings):
+    HOSTED_MODEL_CREDIT_CONFIG: str = Field(
+        description="Model credit configuration in format 'model:credits,model:credits', e.g., 'gpt-4:20,gpt-4o:10'",
+        default="",
+    )
+
+    def get_model_credits(self, model_name: str) -> int:
+        """
+        Get credit value for a specific model name.
+        Returns 1 if model is not found in configuration (default credit).
+
+        :param model_name: The name of the model to search for
+        :return: The credit value for the model
+        """
+        if not self.HOSTED_MODEL_CREDIT_CONFIG:
+            return 1
+
+        try:
+            credit_map = dict(
+                item.strip().split(":", 1) for item in self.HOSTED_MODEL_CREDIT_CONFIG.split(",") if ":" in item
+            )
+
+            # Search for matching model pattern
+            for pattern, credit in credit_map.items():
+                if pattern.strip() == model_name:
+                    return int(credit)
+            return 1  # Default quota if no match found
+        except (ValueError, AttributeError):
+            return 1  # Return default quota if parsing fails
+
+
 class HostedOpenAiConfig(BaseSettings):
     """
     Configuration for hosted OpenAI service
@@ -181,7 +212,7 @@ class HostedFetchAppTemplateConfig(BaseSettings):
     """

     HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
-        description="Mode for fetching app templates: remote, db, or builtin" " default to remote,",
+        description="Mode for fetching app templates: remote, db, or builtin default to remote,",
         default="remote",
     )
@@ -202,5 +233,7 @@ class HostedServiceConfig(
     HostedZhipuAIConfig,
     # moderation
     HostedModerationConfig,
+    # credit config
+    HostedCreditConfig,
 ):
     pass
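The parsing in `get_model_credits` can be exercised on its own; a minimal sketch of the same logic outside the settings class (function name is illustrative):

```python
# Sketch of the credit-string lookup above: "gpt-4:20,gpt-4o:10" maps
# model names to credit costs, with 1 as the default for unknown models.
def get_model_credits(config: str, model_name: str) -> int:
    if not config:
        return 1
    try:
        # "a:1,b:2" -> {"a": "1", "b": "2"}; malformed items without ":" are skipped
        credit_map = dict(item.strip().split(":", 1) for item in config.split(",") if ":" in item)
        for pattern, credit in credit_map.items():
            if pattern.strip() == model_name:
                return int(credit)
        return 1  # no match: default credit
    except (ValueError, AttributeError):
        return 1  # unparsable config: default credit


assert get_model_credits("gpt-4:20,gpt-4o:10", "gpt-4o") == 10
assert get_model_credits("gpt-4:20", "claude-3") == 1
```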
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.15.0",
+        default="1.0.0",
     )

     COMMIT_SHA: str = Field(
@@ -1,9 +1,19 @@
 from contextvars import ContextVar
+from threading import Lock
 from typing import TYPE_CHECKING

 if TYPE_CHECKING:
+    from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
+    from core.tools.plugin_tool.provider import PluginToolProviderController
     from core.workflow.entities.variable_pool import VariablePool


 tenant_id: ContextVar[str] = ContextVar("tenant_id")

 workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
+
+plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
+plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
+
+plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
+plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
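A minimal sketch of the `ContextVar` pattern these declarations rely on (names here are illustrative, not Dify's): each execution context sees its own value, and a paired `Lock` can guard lazy initialization of the shared container.

```python
# Each Context gets an independent value for a ContextVar; set() inside
# ctx.run() affects only that context, not the surrounding one.
from contextvars import ContextVar, copy_context
from threading import Lock

providers: ContextVar[dict[str, str]] = ContextVar("providers")
providers_lock: ContextVar[Lock] = ContextVar("providers_lock")


def init() -> None:
    providers.set({})
    providers_lock.set(Lock())


ctx = copy_context()
ctx.run(init)                  # initialize inside the copied context
print(ctx.run(providers.get))  # -> {}
# providers.get() here, outside ctx, would raise LookupError: never set
```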
@@ -1,12 +1,32 @@
 import mimetypes
 import os
+import platform
 import re
 import urllib.parse
+import warnings
 from collections.abc import Mapping
 from typing import Any
 from uuid import uuid4

 import httpx
+
+try:
+    import magic
+except ImportError:
+    if platform.system() == "Windows":
+        warnings.warn(
+            "To use python-magic guess MIMETYPE, you need to run `pip install python-magic-bin`", stacklevel=2
+        )
+    elif platform.system() == "Darwin":
+        warnings.warn("To use python-magic guess MIMETYPE, you need to run `brew install libmagic`", stacklevel=2)
+    elif platform.system() == "Linux":
+        warnings.warn(
+            "To use python-magic guess MIMETYPE, you need to run `sudo apt-get install libmagic1`", stacklevel=2
+        )
+    else:
+        warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
+    magic = None  # type: ignore
+
 from pydantic import BaseModel

 from configs import dify_config
@@ -47,6 +67,13 @@ def guess_file_info_from_response(response: httpx.Response):
         # If guessing fails, use Content-Type from response headers
         mimetype = response.headers.get("Content-Type", "application/octet-stream")

+    # Use python-magic to guess MIME type if still unknown or generic
+    if mimetype == "application/octet-stream" and magic is not None:
+        try:
+            mimetype = magic.from_buffer(response.content[:1024], mime=True)
+        except magic.MagicException:
+            pass
+
     extension = os.path.splitext(filename)[1]

     # Ensure filename has an extension
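The fallback chain here — filename-based guess first, then content sniffing only when python-magic is importable — can be distilled into a few lines. A standalone sketch (function name illustrative; assumes libmagic is installed where python-magic is used):

```python
# Prefer the cheap stdlib guess by filename; fall back to libmagic content
# sniffing when available; default to application/octet-stream otherwise.
import mimetypes

try:
    import magic  # python-magic, requires the libmagic system library
except ImportError:
    magic = None


def guess_mimetype(filename: str, head: bytes) -> str:
    mimetype, _ = mimetypes.guess_type(filename)
    if mimetype is None and magic is not None:
        mimetype = magic.from_buffer(head[:1024], mime=True)
    return mimetype or "application/octet-stream"


print(guess_mimetype("report.pdf", b"%PDF-1.7 ..."))  # -> application/pdf
```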
@@ -2,7 +2,7 @@ from flask import Blueprint

 from libs.external_api import ExternalApi

-from .app.app_import import AppImportApi, AppImportConfirmApi
+from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
 from .explore.audio import ChatAudioApi, ChatTextApi
 from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
 from .explore.conversation import (
@@ -40,6 +40,7 @@ api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
 # Import App
 api.add_resource(AppImportApi, "/apps/imports")
 api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
+api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")

 # Import other controllers
 from . import admin, apikey, extension, feature, ping, setup, version
@@ -166,4 +167,15 @@ api.add_resource(
 from .tag import tags

 # Import workspace controllers
-from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace
+from .workspace import (
+    account,
+    agent_providers,
+    endpoint,
+    load_balancing_config,
+    members,
+    model_providers,
+    models,
+    plugin,
+    tool_providers,
+    workspace,
+)
@@ -2,6 +2,8 @@ from functools import wraps

 from flask import request
 from flask_restful import Resource, reqparse  # type: ignore
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound, Unauthorized

 from configs import dify_config
@@ -54,9 +56,10 @@ class InsertExploreAppListApi(Resource):
         parser.add_argument("position", type=int, required=True, nullable=False, location="json")
         args = parser.parse_args()

-        app = App.query.filter(App.id == args["app_id"]).first()
+        with Session(db.engine) as session:
+            app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
         if not app:
-            raise NotFound(f'App \'{args["app_id"]}\' is not found')
+            raise NotFound(f"App '{args['app_id']}' is not found")

         site = app.site
         if not site:
@@ -70,7 +73,10 @@ class InsertExploreAppListApi(Resource):
         privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
         custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""

-        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
+        with Session(db.engine) as session:
+            recommended_app = session.execute(
+                select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
+            ).scalar_one_or_none()

         if not recommended_app:
             recommended_app = RecommendedApp(
@@ -110,17 +116,27 @@ class InsertExploreAppApi(Resource):
     @only_edition_cloud
     @admin_required
     def delete(self, app_id):
-        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
+        with Session(db.engine) as session:
+            recommended_app = session.execute(
+                select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
+            ).scalar_one_or_none()

         if not recommended_app:
             return {"result": "success"}, 204

-        app = App.query.filter(App.id == recommended_app.app_id).first()
+        with Session(db.engine) as session:
+            app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()

         if app:
             app.is_public = False

-        installed_apps = InstalledApp.query.filter(
-            InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
-        ).all()
+        with Session(db.engine) as session:
+            installed_apps = session.execute(
+                select(InstalledApp).filter(
+                    InstalledApp.app_id == recommended_app.app_id,
+                    InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
+                )
+            ).all()

         for installed_app in installed_apps:
             db.session.delete(installed_app)
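This hunk (and the next few) migrates legacy Flask-SQLAlchemy `Model.query` calls to SQLAlchemy 2.0-style `select()` statements run on an explicit `Session`. A minimal sketch of the pattern in isolation (`engine` and `App` are stand-ins for `db.engine` and a mapped model):

```python
# 2.0-style query: build a select() statement, execute it on an explicit
# Session, and unwrap with scalar_one_or_none() (None when no row matches).
from sqlalchemy import select
from sqlalchemy.orm import Session


def find_app(engine, App, app_id: str):
    with Session(engine) as session:
        return session.execute(
            select(App).where(App.id == app_id)
        ).scalar_one_or_none()
```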
@@ -3,6 +3,8 @@ from typing import Any

 import flask_restful  # type: ignore
 from flask_login import current_user  # type: ignore
 from flask_restful import Resource, fields, marshal_with
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden

 from extensions.ext_database import db
@@ -26,7 +28,16 @@ api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}


 def _get_resource(resource_id, tenant_id, resource_model):
-    resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
+    if resource_model == App:
+        with Session(db.engine) as session:
+            resource = session.execute(
+                select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
+            ).scalar_one_or_none()
+    else:
+        with Session(db.engine) as session:
+            resource = session.execute(
+                select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
+            ).scalar_one_or_none()

     if resource is None:
         flask_restful.abort(404, message=f"{resource_model.__name__} not found.")
@@ -5,14 +5,16 @@ from flask_restful import Resource, marshal_with, reqparse  # type: ignore
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden

+from controllers.console.app.wraps import get_app_model
 from controllers.console.wraps import (
     account_initialization_required,
     setup_required,
 )
 from extensions.ext_database import db
-from fields.app_fields import app_import_fields
+from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
 from libs.login import login_required
 from models import Account
+from models.model import App
 from services.app_dsl_service import AppDslService, ImportStatus
@@ -88,3 +90,20 @@ class AppImportConfirmApi(Resource):
         if result.status == ImportStatus.FAILED.value:
             return result.model_dump(mode="json"), 400
         return result.model_dump(mode="json"), 200
+
+
+class AppImportCheckDependenciesApi(Resource):
+    @setup_required
+    @login_required
+    @get_app_model
+    @account_initialization_required
+    @marshal_with(app_import_check_dependencies_fields)
+    def get(self, app_model: App):
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            result = import_service.check_dependencies(app_model=app_model)
+
+        return result.model_dump(mode="json"), 200
@@ -22,7 +22,7 @@ from controllers.console.wraps import account_initialization_required, setup_required
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
-from models.model import AppMode
+from models import App, AppMode
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -79,7 +79,7 @@ class ChatMessageTextApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model
-    def post(self, app_model):
+    def post(self, app_model: App):
         from werkzeug.exceptions import InternalServerError

         try:
@@ -98,9 +98,13 @@ class ChatMessageTextApi(Resource):
                 and app_model.workflow.features_dict
             ):
                 text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+                if text_to_speech is None:
+                    raise ValueError("TTS is not enabled")
                 voice = args.get("voice") or text_to_speech.get("voice")
             else:
                 try:
+                    if app_model.app_model_config is None:
+                        raise ValueError("AppModelConfig not found")
                     voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
                 except Exception:
                     voice = None
@ -2,6 +2,7 @@ from datetime import UTC, datetime
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
@@ -50,33 +51,37 @@ class AppSite(Resource):
         if not current_user.is_editor:
             raise Forbidden()

-        site = Site.query.filter(Site.app_id == app_model.id).one_or_404()
+        with Session(db.engine) as session:
+            site = session.query(Site).filter(Site.app_id == app_model.id).first()

-        for attr_name in [
-            "title",
-            "icon_type",
-            "icon",
-            "icon_background",
-            "description",
-            "default_language",
-            "chat_color_theme",
-            "chat_color_theme_inverted",
-            "customize_domain",
-            "copyright",
-            "privacy_policy",
-            "custom_disclaimer",
-            "customize_token_strategy",
-            "prompt_public",
-            "show_workflow_steps",
-            "use_icon_as_answer_icon",
-        ]:
-            value = args.get(attr_name)
-            if value is not None:
-                setattr(site, attr_name, value)
+            if not site:
+                raise NotFound

-        site.updated_by = current_user.id
-        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
-        db.session.commit()
+            for attr_name in [
+                "title",
+                "icon_type",
+                "icon",
+                "icon_background",
+                "description",
+                "default_language",
+                "chat_color_theme",
+                "chat_color_theme_inverted",
+                "customize_domain",
+                "copyright",
+                "privacy_policy",
+                "custom_disclaimer",
+                "customize_token_strategy",
+                "prompt_public",
+                "show_workflow_steps",
+                "use_icon_as_answer_icon",
+            ]:
+                value = args.get(attr_name)
+                if value is not None:
+                    setattr(site, attr_name, value)
+
+            site.updated_by = current_user.id
+            site.updated_at = datetime.now(UTC).replace(tzinfo=None)
+            session.commit()

         return site

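Note: this hunk is one instance of a pattern repeated throughout the comparison: the legacy `Model.query` class property gives way to an explicit `Session(db.engine)` block that scopes the read, the mutation, and the commit together. A self-contained sketch of the before/after with an in-memory SQLite engine (model and column are invented for illustration):

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Site(Base):
    __tablename__ = "sites"
    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str]

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Site(id=1, title="old"))
    session.commit()

# Explicit session replaces the legacy `Site.query.filter(...)` class property.
with Session(engine) as session:
    site = session.execute(select(Site).filter_by(id=1)).scalar_one_or_none()
    if site is not None:
        site.title = "new"
        session.commit()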
@@ -20,6 +20,7 @@ from libs import helper
 from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
 from models import App
+from models.account import Account
 from models.model import AppMode
 from services.app_generate_service import AppGenerateService
 from services.errors.app import WorkflowHashNotEqualError
@@ -96,6 +97,9 @@ class DraftWorkflowApi(Resource):
         else:
             abort(415)

+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
         workflow_service = WorkflowService()

         try:
@@ -139,6 +143,9 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, location="json")
         parser.add_argument("query", type=str, required=True, location="json", default="")
@@ -160,7 +167,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
             raise ConversationCompletedError()
         except ValueError as e:
             raise e
-        except Exception as e:
+        except Exception:
             logging.exception("internal server error.")
             raise InternalServerError()

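Note: dropping the unused `as e` binding (repeated in several hunks below) is safe because `logging.exception` picks up the active exception from the handler context and logs the full traceback; the name only triggered unused-variable lint warnings. Quick demonstration:

import logging

try:
    1 / 0
except Exception:  # no `as e` binding needed
    logging.exception("internal server error.")  # still logs the ZeroDivisionError traceback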
@@ -178,38 +185,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, location="json")
-        args = parser.parse_args()
-
-        try:
-            response = AppGenerateService.generate_single_iteration(
-                app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
-            )
-
-            return helper.compact_generate_response(response)
-        except services.errors.conversation.ConversationNotExistsError:
-            raise NotFound("Conversation Not Exists.")
-        except services.errors.conversation.ConversationCompletedError:
-            raise ConversationCompletedError()
-        except ValueError as e:
-            raise e
-        except Exception as e:
-            logging.exception("internal server error.")
-            raise InternalServerError()
-
-
-class WorkflowDraftRunIterationNodeApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.WORKFLOW])
-    def post(self, app_model: App, node_id: str):
-        """
-        Run draft workflow iteration node
-        """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
+        if not isinstance(current_user, Account):
             raise Forbidden()

         parser = reqparse.RequestParser()
@@ -228,7 +204,44 @@ class WorkflowDraftRunIterationNodeApi(Resource):
             raise ConversationCompletedError()
         except ValueError as e:
             raise e
-        except Exception as e:
+        except Exception:
             logging.exception("internal server error.")
             raise InternalServerError()

+
+class WorkflowDraftRunIterationNodeApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @get_app_model(mode=[AppMode.WORKFLOW])
+    def post(self, app_model: App, node_id: str):
+        """
+        Run draft workflow iteration node
+        """
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("inputs", type=dict, location="json")
+        args = parser.parse_args()
+
+        try:
+            response = AppGenerateService.generate_single_iteration(
+                app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
+            )
+
+            return helper.compact_generate_response(response)
+        except services.errors.conversation.ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
+        except services.errors.conversation.ConversationCompletedError:
+            raise ConversationCompletedError()
+        except ValueError as e:
+            raise e
+        except Exception:
+            logging.exception("internal server error.")
+            raise InternalServerError()

@@ -246,6 +259,9 @@ class DraftWorkflowRunApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         parser.add_argument("files", type=list, required=False, location="json")
@@ -294,13 +310,20 @@ class DraftWorkflowNodeRunApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
         parser = reqparse.RequestParser()
         parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()

+        inputs = args.get("inputs")
+        if inputs == None:
+            raise ValueError("missing inputs")
+
         workflow_service = WorkflowService()
         workflow_node_execution = workflow_service.run_draft_workflow_node(
-            app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user
+            app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
         )

         return workflow_node_execution
@@ -339,6 +362,9 @@ class PublishedWorkflowApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
         workflow_service = WorkflowService()
         workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)

@@ -376,12 +402,17 @@ class DefaultBlockConfigApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
         parser = reqparse.RequestParser()
         parser.add_argument("q", type=str, location="args")
         args = parser.parse_args()

+        q = args.get("q")
+
         filters = None
-        if args.get("q"):
+        if q:
             try:
                 filters = json.loads(args.get("q", ""))
             except json.JSONDecodeError:
@@ -407,6 +438,9 @@ class ConvertToWorkflowApi(Resource):
         if not current_user.is_editor:
             raise Forbidden()

+        if not isinstance(current_user, Account):
+            raise Forbidden()
+
         if request.data:
             parser = reqparse.RequestParser()
             parser.add_argument("name", type=str, required=False, nullable=True, location="json")

@@ -59,3 +59,9 @@ class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException):
     error_code = "email_code_account_deletion_rate_limit_exceeded"
     description = "Too many account deletion emails have been sent. Please try again in 5 minutes."
     code = 429
+
+
+class EmailPasswordResetLimitError(BaseHTTPException):
+    error_code = "email_password_reset_limit"
+    description = "Too many failed password reset attempts. Please try again in 24 hours."
+    code = 429
@@ -3,10 +3,18 @@ import secrets

 from flask import request
 from flask_restful import Resource, reqparse  # type: ignore
+from sqlalchemy import select
+from sqlalchemy.orm import Session

 from constants.languages import languages
 from controllers.console import api
-from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError
+from controllers.console.auth.error import (
+    EmailCodeError,
+    EmailPasswordResetLimitError,
+    InvalidEmailError,
+    InvalidTokenError,
+    PasswordMismatchError,
+)
 from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
 from controllers.console.wraps import setup_required
 from events.tenant_event import tenant_was_created
@@ -37,7 +45,8 @@ class ForgotPasswordSendEmailApi(Resource):
         else:
             language = "en-US"

-        account = Account.query.filter_by(email=args["email"]).first()
+        with Session(db.engine) as session:
+            account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
         token = None
         if account is None:
             if FeatureService.get_system_features().is_allow_register:
@@ -62,6 +71,10 @@ class ForgotPasswordCheckApi(Resource):

         user_email = args["email"]

+        is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
+        if is_forgot_password_error_rate_limit:
+            raise EmailPasswordResetLimitError()
+
         token_data = AccountService.get_reset_password_data(args["token"])
         if token_data is None:
             raise InvalidTokenError()
@@ -70,8 +83,10 @@ class ForgotPasswordCheckApi(Resource):
             raise InvalidEmailError()

         if args["code"] != token_data.get("code"):
+            AccountService.add_forgot_password_error_rate_limit(args["email"])
             raise EmailCodeError()

+        AccountService.reset_forgot_password_error_rate_limit(args["email"])
         return {"is_valid": True, "email": token_data.get("email")}

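Note: together with the previous hunk, this threads a failed-attempt limit through the reset-code check: each wrong code increments a per-email counter, a successful check clears it, and the limit short-circuits before any token lookup. A standalone sketch of the pattern with an in-process dict standing in for the real store (threshold and names are assumptions; the actual helpers live in AccountService):

_attempts: dict[str, int] = {}
MAX_ATTEMPTS = 5  # assumed threshold

def check_code(email: str, given: str, expected: str) -> bool:
    if _attempts.get(email, 0) >= MAX_ATTEMPTS:
        raise PermissionError("too many failed attempts")
    if given != expected:
        _attempts[email] = _attempts.get(email, 0) + 1  # count the failure
        return False
    _attempts.pop(email, None)  # success resets the counter
    return True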
@@ -104,7 +119,8 @@ class ForgotPasswordResetApi(Resource):
         password_hashed = hash_password(new_password, salt)
         base64_password_hashed = base64.b64encode(password_hashed).decode()

-        account = Account.query.filter_by(email=reset_data.get("email")).first()
+        with Session(db.engine) as session:
+            account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
         if account:
             account.password = base64_password_hashed
             account.password_salt = base64_salt
@@ -125,7 +141,7 @@ class ForgotPasswordResetApi(Resource):
             )
         except WorkSpaceNotAllowedCreateError:
             pass
-        except AccountRegisterError as are:
+        except AccountRegisterError:
             raise AccountInFreezeError()

         return {"result": "success"}

@@ -5,6 +5,8 @@ from typing import Optional

 import requests
 from flask import current_app, redirect, request
 from flask_restful import Resource  # type: ignore
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Unauthorized

 from configs import dify_config
@@ -135,7 +137,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
     account: Optional[Account] = Account.get_by_openid(provider, user_info.id)

     if not account:
-        account = Account.query.filter_by(email=user_info.email).first()
+        with Session(db.engine) as session:
+            account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()

     return account

@@ -4,6 +4,8 @@ import json
 from flask import request
 from flask_login import current_user  # type: ignore
 from flask_restful import Resource, marshal_with, reqparse  # type: ignore
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound

 from controllers.console import api
@@ -76,7 +78,10 @@ class DataSourceApi(Resource):
     def patch(self, binding_id, action):
         binding_id = str(binding_id)
         action = str(action)
-        data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
+        with Session(db.engine) as session:
+            data_source_binding = session.execute(
+                select(DataSourceOauthBinding).filter_by(id=binding_id)
+            ).scalar_one_or_none()
         if data_source_binding is None:
             raise NotFound("Data source binding not found.")
         # enable binding
@@ -108,47 +113,53 @@ class DataSourceNotionListApi(Resource):
     def get(self):
         dataset_id = request.args.get("dataset_id", default=None, type=str)
         exist_page_ids = []
-        # import notion in the exist dataset
-        if dataset_id:
-            dataset = DatasetService.get_dataset(dataset_id)
-            if not dataset:
-                raise NotFound("Dataset not found.")
-            if dataset.data_source_type != "notion_import":
-                raise ValueError("Dataset is not notion type.")
-            documents = Document.query.filter_by(
-                dataset_id=dataset_id,
-                tenant_id=current_user.current_tenant_id,
-                data_source_type="notion_import",
-                enabled=True,
-            ).all()
-            if documents:
-                for document in documents:
-                    data_source_info = json.loads(document.data_source_info)
-                    exist_page_ids.append(data_source_info["notion_page_id"])
-        # get all authorized pages
-        data_source_bindings = DataSourceOauthBinding.query.filter_by(
-            tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
-        ).all()
-        if not data_source_bindings:
-            return {"notion_info": []}, 200
-        pre_import_info_list = []
-        for data_source_binding in data_source_bindings:
-            source_info = data_source_binding.source_info
-            pages = source_info["pages"]
-            # Filter out already bound pages
-            for page in pages:
-                if page["page_id"] in exist_page_ids:
-                    page["is_bound"] = True
-                else:
-                    page["is_bound"] = False
-            pre_import_info = {
-                "workspace_name": source_info["workspace_name"],
-                "workspace_icon": source_info["workspace_icon"],
-                "workspace_id": source_info["workspace_id"],
-                "pages": pages,
-            }
-            pre_import_info_list.append(pre_import_info)
-        return {"notion_info": pre_import_info_list}, 200
+        with Session(db.engine) as session:
+            # import notion in the exist dataset
+            if dataset_id:
+                dataset = DatasetService.get_dataset(dataset_id)
+                if not dataset:
+                    raise NotFound("Dataset not found.")
+                if dataset.data_source_type != "notion_import":
+                    raise ValueError("Dataset is not notion type.")
+
+                documents = session.execute(
+                    select(Document).filter_by(
+                        dataset_id=dataset_id,
+                        tenant_id=current_user.current_tenant_id,
+                        data_source_type="notion_import",
+                        enabled=True,
+                    )
+                ).all()
+                if documents:
+                    for document in documents:
+                        data_source_info = json.loads(document.data_source_info)
+                        exist_page_ids.append(data_source_info["notion_page_id"])
+            # get all authorized pages
+            data_source_bindings = session.scalars(
+                select(DataSourceOauthBinding).filter_by(
+                    tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
+                )
+            ).all()
+            if not data_source_bindings:
+                return {"notion_info": []}, 200
+            pre_import_info_list = []
+            for data_source_binding in data_source_bindings:
+                source_info = data_source_binding.source_info
+                pages = source_info["pages"]
+                # Filter out already bound pages
+                for page in pages:
+                    if page["page_id"] in exist_page_ids:
+                        page["is_bound"] = True
+                    else:
+                        page["is_bound"] = False
+                pre_import_info = {
+                    "workspace_name": source_info["workspace_name"],
+                    "workspace_icon": source_info["workspace_icon"],
+                    "workspace_id": source_info["workspace_id"],
+                    "pages": pages,
+                }
+                pre_import_info_list.append(pre_import_info)
+            return {"notion_info": pre_import_info_list}, 200


 class DataSourceNotionApi(Resource):
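Note: the rewritten method mixes two result styles worth distinguishing: `session.scalars(select(...))` yields ORM instances directly, while `session.execute(select(...))` yields Row tuples that usually want `.scalars()` before iteration. A quick self-contained illustration:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Doc(Base):
    __tablename__ = "docs"
    id: Mapped[int] = mapped_column(primary_key=True)

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Doc(id=1))
    session.commit()
    rows = session.execute(select(Doc)).all()  # list of Row tuples
    docs = session.scalars(select(Doc)).all()  # list of Doc instances
    print(type(rows[0]).__name__, type(docs[0]).__name__)  # Row Doc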
@@ -158,14 +169,17 @@ class DataSourceNotionApi(Resource):
     def get(self, workspace_id, page_id, page_type):
         workspace_id = str(workspace_id)
         page_id = str(page_id)
-        data_source_binding = DataSourceOauthBinding.query.filter(
-            db.and_(
-                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == "notion",
-                DataSourceOauthBinding.disabled == False,
-                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
-            )
-        ).first()
+        with Session(db.engine) as session:
+            data_source_binding = session.execute(
+                select(DataSourceOauthBinding).filter(
+                    db.and_(
+                        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                        DataSourceOauthBinding.provider == "notion",
+                        DataSourceOauthBinding.disabled == False,
+                        DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+                    )
+                )
+            ).scalar_one_or_none()
         if not data_source_binding:
             raise NotFound("Data source binding not found.")

@@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required, enterpris
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
+from core.plugin.entities.plugin import ModelProviderID
 from core.provider_manager import ProviderManager
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
@@ -52,12 +53,12 @@ class DatasetListApi(Resource):
         # provider = request.args.get("provider", default="vendor")
         search = request.args.get("keyword", default=None, type=str)
         tag_ids = request.args.getlist("tag_ids")
-
+        include_all = request.args.get("include_all", default="false").lower() == "true"
         if ids:
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
             datasets, total = DatasetService.get_datasets(
-                page, limit, current_user.current_tenant_id, current_user, search, tag_ids
+                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
             )

         # check embedding setting
@@ -72,7 +73,9 @@ class DatasetListApi(Resource):

         data = marshal(datasets, dataset_detail_fields)
         for item in data:
+            # convert embedding_model_provider to plugin standard format
             if item["indexing_technique"] == "high_quality":
+                item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
                 item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                 if item_model in model_names:
                     item["embedding_available"] = True
@@ -457,7 +460,7 @@ class DatasetIndexingEstimateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -619,8 +622,7 @@ class DatasetRetrievalSettingApi(Resource):
         vector_type = dify_config.VECTOR_STORE
         match vector_type:
             case (
-                VectorType.MILVUS
-                | VectorType.RELYT
+                VectorType.RELYT
                 | VectorType.PGVECTOR
                 | VectorType.TIDB_VECTOR
                 | VectorType.CHROMA
@@ -645,6 +647,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.TIDB_ON_QDRANT
                 | VectorType.LINDORM
                 | VectorType.COUCHBASE
+                | VectorType.MILVUS
             ):
                 return {
                     "retrieval_method": [

@@ -7,7 +7,6 @@ from flask import request
 from flask_login import current_user  # type: ignore
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse  # type: ignore
 from sqlalchemy import asc, desc
-from transformers.hf_argparser import string_to_bool  # type: ignore
 from werkzeug.exceptions import Forbidden, NotFound

 import services
@@ -40,6 +39,7 @@ from core.indexing_runner import IndexingRunner
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
+from core.plugin.manager.exc import PluginDaemonClientSideError
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -150,8 +150,20 @@ class DatasetDocumentListApi(Resource):
         sort = request.args.get("sort", default="-created_at", type=str)
         # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
         try:
-            fetch = string_to_bool(request.args.get("fetch", default="false"))
-        except (ArgumentTypeError, ValueError, Exception) as e:
+            fetch_val = request.args.get("fetch", default="false")
+            if isinstance(fetch_val, bool):
+                fetch = fetch_val
+            else:
+                if fetch_val.lower() in ("yes", "true", "t", "y", "1"):
+                    fetch = True
+                elif fetch_val.lower() in ("no", "false", "f", "n", "0"):
+                    fetch = False
+                else:
+                    raise ArgumentTypeError(
+                        f"Truthy value expected: got {fetch_val} but expected one of yes/no, true/false, t/f, y/n, 1/0 "
+                        f"(case insensitive)."
+                    )
+        except (ArgumentTypeError, ValueError, Exception):
            fetch = False
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
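Note: the hunk inlines the truthy-string parsing that was previously imported from `transformers.hf_argparser` (removed in the import hunk above), so this controller no longer pulls in transformers for one helper. The same logic as a standalone, testable function:

from argparse import ArgumentTypeError

def string_to_bool(value) -> bool:
    # Mirrors the inlined parser above.
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise ArgumentTypeError(f"Truthy value expected: got {value!r}")

assert string_to_bool("Yes") is True
assert string_to_bool("0") is False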
@@ -350,8 +362,7 @@ class DatasetInitApi(Resource):
             )
         except InvokeAuthorizationError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -430,6 +441,8 @@ class DocumentIndexingEstimateApi(DocumentResource):
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
+        except PluginDaemonClientSideError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except Exception as e:
             raise IndexingEstimateError(str(e))

@@ -526,11 +539,12 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
             return response.model_dump(), 200
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
+        except PluginDaemonClientSideError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except Exception as e:
             raise IndexingEstimateError(str(e))

@@ -168,8 +168,7 @@ class DatasetDocumentSegmentApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -217,8 +216,7 @@ class DatasetDocumentSegmentAddApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -267,8 +265,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -368,9 +365,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
                 result = []
                 for index, row in df.iterrows():
                     if document.doc_form == "qa_model":
-                        data = {"content": row[0], "answer": row[1]}
+                        data = {"content": row.iloc[0], "answer": row.iloc[1]}
                     else:
-                        data = {"content": row[0]}
+                        data = {"content": row.iloc[0]}
                     result.append(data)
                 if len(result) == 0:
                     raise ValueError("The CSV file is empty.")
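Note: `row[0]` on a pandas Series is label-based lookup, which warns (and eventually fails) when the columns are not integer-labeled; `row.iloc[0]` is unambiguous positional access. Minimal illustration:

import pandas as pd

df = pd.DataFrame([["What is Dify?", "An LLM app platform."]])
for _, row in df.iterrows():
    # Positional access works no matter how the columns are labeled.
    print({"content": row.iloc[0], "answer": row.iloc[1]})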
@@ -437,8 +434,7 @@ class ChildChunkAddApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)

@@ -32,7 +32,7 @@ class ConversationListApi(InstalledAppResource):

         pinned = None
         if "pinned" in args and args["pinned"] is not None:
-            pinned = True if args["pinned"] == "true" else False
+            pinned = args["pinned"] == "true"

         try:
             with Session(db.engine) as session:

@@ -50,7 +50,7 @@ class MessageListApi(InstalledAppResource):

         try:
             return MessageService.pagination_by_first_id(
-                app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
+                app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
             )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")

@@ -2,8 +2,11 @@ import os

 from flask import session
 from flask_restful import Resource, reqparse  # type: ignore
+from sqlalchemy import select
+from sqlalchemy.orm import Session

 from configs import dify_config
+from extensions.ext_database import db
 from libs.helper import StrLen
 from models.model import DifySetup
 from services.account_service import TenantService
@@ -42,7 +45,11 @@ class InitValidateAPI(Resource):
 def get_init_validate_status():
     if dify_config.EDITION == "SELF_HOSTED":
         if os.environ.get("INIT_PASSWORD"):
-            return session.get("is_init_validated") or DifySetup.query.first()
+            if session.get("is_init_validated"):
+                return True
+
+            with Session(db.engine) as db_session:
+                return db_session.execute(select(DifySetup)).scalar_one_or_none()

     return True

@@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse  # type: ignore
 from configs import dify_config
 from libs.helper import StrLen, email, extract_remote_ip
 from libs.password import valid_password
-from models.model import DifySetup
+from models.model import DifySetup, db
 from services.account_service import RegisterService, TenantService

 from . import api
@@ -52,8 +52,9 @@ class SetupApi(Resource):

 def get_setup_status():
     if dify_config.EDITION == "SELF_HOSTED":
-        return DifySetup.query.first()
-    return True
+        return db.session.query(DifySetup).first()
+    else:
+        return True


 api.add_resource(SetupApi, "/setup")

@@ -0,0 +1,56 @@
+from functools import wraps
+
+from flask_login import current_user  # type: ignore
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import Forbidden
+
+from extensions.ext_database import db
+from models.account import TenantPluginPermission
+
+
+def plugin_permission_required(
+    install_required: bool = False,
+    debug_required: bool = False,
+):
+    def interceptor(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            user = current_user
+            tenant_id = user.current_tenant_id
+
+            with Session(db.engine) as session:
+                permission = (
+                    session.query(TenantPluginPermission)
+                    .filter(
+                        TenantPluginPermission.tenant_id == tenant_id,
+                    )
+                    .first()
+                )
+
+                if not permission:
+                    # no permission set, allow access for everyone
+                    return view(*args, **kwargs)
+
+                if install_required:
+                    if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
+                        raise Forbidden()
+                    if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
+                        if not user.is_admin_or_owner:
+                            raise Forbidden()
+                    if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
+                        pass
+
+                if debug_required:
+                    if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
+                        raise Forbidden()
+                    if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
+                        if not user.is_admin_or_owner:
+                            raise Forbidden()
+                    if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
+                        pass
+
+                return view(*args, **kwargs)
+
+        return decorated
+
+    return interceptor
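Note: the decorator above fails open when no tenant permission row exists and otherwise gates on the requested capability. A runnable miniature of the same shape, with stub objects in place of Flask and the ORM (all names below are illustrative):

from functools import wraps

class Forbidden(Exception):
    pass

class Permission:  # stand-in for the tenant's TenantPluginPermission row
    install_permission = "admins"

class User:  # stand-in for flask_login's current_user
    is_admin_or_owner = False

current_user = User()

def plugin_permission_required(install_required: bool = False):
    def interceptor(view):
        @wraps(view)
        def decorated(*args, **kwargs):
            permission = Permission()  # the real code loads this per tenant
            if install_required and permission.install_permission == "admins":
                if not current_user.is_admin_or_owner:
                    raise Forbidden()
            return view(*args, **kwargs)
        return decorated
    return interceptor

@plugin_permission_required(install_required=True)
def install_plugin():
    return "installed"

try:
    install_plugin()
except Forbidden:
    print("403: only admins may install plugins for this tenant")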
api/controllers/console/workspace/agent_providers.py (new file, +36 lines)
@@ -0,0 +1,36 @@
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource  # type: ignore
+
+from controllers.console import api
+from controllers.console.wraps import account_initialization_required, setup_required
+from core.model_runtime.utils.encoders import jsonable_encoder
+from libs.login import login_required
+from services.agent_service import AgentService
+
+
+class AgentProviderListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
+
+        return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
+
+
+class AgentProviderApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider_name: str):
+        user = current_user
+        user_id = user.id
+        tenant_id = user.current_tenant_id
+        return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
+
+
+api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
+api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")
api/controllers/console/workspace/endpoint.py (new file, +205 lines)
@@ -0,0 +1,205 @@
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, reqparse  # type: ignore
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.wraps import account_initialization_required, setup_required
+from core.model_runtime.utils.encoders import jsonable_encoder
+from libs.login import login_required
+from services.plugin.endpoint_service import EndpointService
+
+
+class EndpointCreateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        user = current_user
+        if not user.is_admin_or_owner:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("plugin_unique_identifier", type=str, required=True)
+        parser.add_argument("settings", type=dict, required=True)
+        parser.add_argument("name", type=str, required=True)
+        args = parser.parse_args()
+
+        plugin_unique_identifier = args["plugin_unique_identifier"]
+        settings = args["settings"]
+        name = args["name"]
+
+        return {
+            "success": EndpointService.create_endpoint(
+                tenant_id=user.current_tenant_id,
+                user_id=user.id,
+                plugin_unique_identifier=plugin_unique_identifier,
+                name=name,
+                settings=settings,
+            )
+        }
+
+
+class EndpointListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        user = current_user
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("page", type=int, required=True, location="args")
+        parser.add_argument("page_size", type=int, required=True, location="args")
+        args = parser.parse_args()
+
+        page = args["page"]
+        page_size = args["page_size"]
+
+        return jsonable_encoder(
+            {
+                "endpoints": EndpointService.list_endpoints(
+                    tenant_id=user.current_tenant_id,
+                    user_id=user.id,
+                    page=page,
+                    page_size=page_size,
+                )
+            }
+        )
+
+
+class EndpointListForSinglePluginApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        user = current_user
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("page", type=int, required=True, location="args")
+        parser.add_argument("page_size", type=int, required=True, location="args")
+        parser.add_argument("plugin_id", type=str, required=True, location="args")
+        args = parser.parse_args()
+
+        page = args["page"]
+        page_size = args["page_size"]
+        plugin_id = args["plugin_id"]
+
+        return jsonable_encoder(
+            {
+                "endpoints": EndpointService.list_endpoints_for_single_plugin(
+                    tenant_id=user.current_tenant_id,
+                    user_id=user.id,
+                    plugin_id=plugin_id,
+                    page=page,
+                    page_size=page_size,
+                )
+            }
+        )
+
+
+class EndpointDeleteApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        user = current_user
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("endpoint_id", type=str, required=True)
+        args = parser.parse_args()
+
+        if not user.is_admin_or_owner:
+            raise Forbidden()
+
+        endpoint_id = args["endpoint_id"]
+
+        return {
+            "success": EndpointService.delete_endpoint(
+                tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
+            )
+        }
+
+
+class EndpointUpdateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        user = current_user
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("endpoint_id", type=str, required=True)
+        parser.add_argument("settings", type=dict, required=True)
+        parser.add_argument("name", type=str, required=True)
+        args = parser.parse_args()
+
+        endpoint_id = args["endpoint_id"]
+        settings = args["settings"]
+        name = args["name"]
+
+        if not user.is_admin_or_owner:
+            raise Forbidden()
+
+        return {
+            "success": EndpointService.update_endpoint(
+                tenant_id=user.current_tenant_id,
+                user_id=user.id,
+                endpoint_id=endpoint_id,
+                name=name,
+                settings=settings,
+            )
+        }
+
+
+class EndpointEnableApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        user = current_user
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("endpoint_id", type=str, required=True)
+        args = parser.parse_args()
+
+        endpoint_id = args["endpoint_id"]
+
+        if not user.is_admin_or_owner:
+            raise Forbidden()
+
+        return {
+            "success": EndpointService.enable_endpoint(
+                tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
+            )
+        }
+
+
+class EndpointDisableApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        user = current_user
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("endpoint_id", type=str, required=True)
+        args = parser.parse_args()
+
+        endpoint_id = args["endpoint_id"]
+
+        if not user.is_admin_or_owner:
+            raise Forbidden()
+
+        return {
+            "success": EndpointService.disable_endpoint(
+                tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
+            )
+        }
+
+
+api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
+api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
+api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
+api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
+api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
+api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
+api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")
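Note: all endpoint mutations above go through POST routes with a JSON body rather than REST-style verbs. A hedged sketch of exercising them with requests (base URL, auth scheme, and the placeholder values are assumptions for illustration only):

import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <session-token>"}  # placeholder auth

# Create, then disable, a plugin endpoint (field names follow the parsers above).
requests.post(
    f"{BASE}/workspaces/current/endpoints/create",
    json={"plugin_unique_identifier": "<id>", "settings": {}, "name": "demo"},
    headers=HEADERS,
)
requests.post(
    f"{BASE}/workspaces/current/endpoints/disable",
    json={"endpoint_id": "<endpoint-id>"},
    headers=HEADERS,
)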
@@ -112,10 +112,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
 # Load Balancing Config
 api.add_resource(
     LoadBalancingCredentialsValidateApi,
-    "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate",
+    "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate",
 )

 api.add_resource(
     LoadBalancingConfigCredentialsValidateApi,
-    "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
+    "/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate",
 )

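Note: this and the following route hunks swap `<string:provider>` for `<path:provider>` because Flask's default string converter stops at `/`, while plugin-era provider names can contain one (an org-scoped name like `langgenius/openai`, presumably). A standalone Flask illustration of the difference:

from flask import Flask

app = Flask(__name__)

@app.route("/string/<string:provider>")
def by_string(provider):
    return provider

@app.route("/path/<path:provider>")
def by_path(provider):
    return provider

client = app.test_client()
print(client.get("/string/langgenius/openai").status_code)           # 404: '/' not matched
print(client.get("/path/langgenius/openai").get_data(as_text=True))  # langgenius/openai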
@@ -79,7 +79,7 @@ class ModelProviderValidateApi(Resource):
         response = {"result": "success" if result else "error"}

         if not result:
-            response["error"] = error
+            response["error"] = error or "Unknown error"

         return response

@@ -125,9 +125,10 @@ class ModelProviderIconApi(Resource):
     Get model provider icon
     """

-    def get(self, provider: str, icon_type: str, lang: str):
+    def get(self, tenant_id: str, provider: str, icon_type: str, lang: str):
         model_provider_service = ModelProviderService()
         icon, mimetype = model_provider_service.get_model_provider_icon(
+            tenant_id=tenant_id,
             provider=provider,
             icon_type=icon_type,
             lang=lang,
@@ -183,53 +184,17 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
         return data


-class ModelProviderFreeQuotaSubmitApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider: str):
-        model_provider_service = ModelProviderService()
-        result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider)
-
-        return result
-
-
-class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self, provider: str):
-        parser = reqparse.RequestParser()
-        parser.add_argument("token", type=str, required=False, nullable=True, location="args")
-        args = parser.parse_args()
-
-        model_provider_service = ModelProviderService()
-        result = model_provider_service.free_quota_qualification_verify(
-            tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"]
-        )
-
-        return result
-
-
 api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")

-api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials")
-api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
-api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
-api.add_resource(
-    ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
-)
+api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
+api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
+api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")

 api.add_resource(
-    PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type"
+    PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
 )
+api.add_resource(ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<path:provider>/checkout-url")
-api.add_resource(
-    ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url"
-)
-api.add_resource(
-    ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit"
-)
 api.add_resource(
-    ModelProviderFreeQuotaQualificationVerifyApi,
-    "/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify",
+    ModelProviderIconApi,
+    "/workspaces/<string:tenant_id>/model-providers/<path:provider>/<string:icon_type>/<string:lang>",
 )

@@ -325,7 +325,7 @@ class ModelProviderModelValidateApi(Resource):
         response = {"result": "success" if result else "error"}

         if not result:
-            response["error"] = error
+            response["error"] = error or ""

         return response

@@ -362,26 +362,26 @@ class ModelProviderAvailableModelApi(Resource):
         return jsonable_encoder({"data": models})


-api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
+api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
 api.add_resource(
     ModelProviderModelEnableApi,
-    "/workspaces/current/model-providers/<string:provider>/models/enable",
+    "/workspaces/current/model-providers/<path:provider>/models/enable",
     endpoint="model-provider-model-enable",
 )
 api.add_resource(
     ModelProviderModelDisableApi,
-    "/workspaces/current/model-providers/<string:provider>/models/disable",
+    "/workspaces/current/model-providers/<path:provider>/models/disable",
     endpoint="model-provider-model-disable",
 )
 api.add_resource(
-    ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
+    ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
 )
 api.add_resource(
-    ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
+    ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
 )

 api.add_resource(
-    ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
+    ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
 )
 api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
 api.add_resource(DefaultModelApi, "/workspaces/current/default-model")

api/controllers/console/workspace/plugin.py (new file, +475 lines)
@ -0,0 +1,475 @@
|
||||
import io
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||
from libs.login import login_required
|
||||
from models.account import TenantPluginPermission
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return {
|
||||
"key": PluginService.get_debugging_key(tenant_id),
|
||||
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
|
||||
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
|
||||
}
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
try:
|
||||
plugins = PluginService.list(tenant_id)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
class PluginListInstallationsFromIdsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_ids", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder({"plugins": plugins})
|
||||
|
||||
|
||||
class PluginIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
req = reqparse.RequestParser()
|
||||
req.add_argument("tenant_id", type=str, required=True, location="args")
|
||||
req.add_argument("filename", type=str, required=True, location="args")
|
||||
args = req.parse_args()
|
||||
|
||||
try:
|
||||
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
|
||||
|
||||
class PluginUploadFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
file = request.files["pkg"]
|
||||
|
||||
# check file size
|
||||
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
content = file.read()
|
||||
try:
|
||||
response = PluginService.upload_pkg(tenant_id, content)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginUploadFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
parser.add_argument("version", type=str, required=True, location="json")
|
||||
parser.add_argument("package", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginUploadFromBundleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
file = request.files["bundle"]
|
||||
|
||||
# check file size
|
||||
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
|
||||
raise ValueError("File size exceeds the maximum allowed size")
|
||||
|
||||
content = file.read()
|
||||
try:
|
||||
response = PluginService.upload_bundle(tenant_id, content)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginInstallFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
if not isinstance(plugin_unique_identifier, str):
|
||||
raise ValueError("Invalid plugin unique identifier")
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginInstallFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("repo", type=str, required=True, location="json")
|
||||
parser.add_argument("version", type=str, required=True, location="json")
|
||||
parser.add_argument("package", type=str, required=True, location="json")
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_github(
|
||||
tenant_id,
|
||||
args["plugin_unique_identifier"],
|
||||
args["repo"],
|
||||
args["version"],
|
||||
args["package"],
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginInstallFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(install_required=True)
|
||||
def post(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# check if all plugin_unique_identifiers are valid string
|
||||
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||
if not isinstance(plugin_unique_identifier, str):
|
||||
raise ValueError("Invalid plugin unique identifier")
|
||||
|
||||
try:
|
||||
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
return jsonable_encoder(response)
|
||||
|
||||
|
||||
class PluginFetchManifestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"manifest": PluginService.fetch_plugin_manifest(
|
||||
tenant_id, args["plugin_unique_identifier"]
|
||||
).model_dump()
|
||||
}
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginFetchInstallTasksApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=int, required=True, location="args")
|
||||
parser.add_argument("page_size", type=int, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginFetchInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def get(self, task_id: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
try:
|
||||
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
||||
except PluginDaemonClientSideError as e:
|
||||
raise ValueError(e)
|
||||
|
||||
|
||||
class PluginDeleteInstallTaskApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @plugin_permission_required(debug_required=True)
    def post(self, task_id: str):
        tenant_id = current_user.current_tenant_id

        try:
            return {"success": PluginService.delete_install_task(tenant_id, task_id)}
        except PluginDaemonClientSideError as e:
            raise ValueError(e)


class PluginDeleteAllInstallTaskItemsApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @plugin_permission_required(debug_required=True)
    def post(self):
        tenant_id = current_user.current_tenant_id

        try:
            return {"success": PluginService.delete_all_install_task_items(tenant_id)}
        except PluginDaemonClientSideError as e:
            raise ValueError(e)


class PluginDeleteInstallTaskItemApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @plugin_permission_required(debug_required=True)
    def post(self, task_id: str, identifier: str):
        tenant_id = current_user.current_tenant_id

        try:
            return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
        except PluginDaemonClientSideError as e:
            raise ValueError(e)


class PluginUpgradeFromMarketplaceApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @plugin_permission_required(debug_required=True)
    def post(self):
        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
        parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
        args = parser.parse_args()

        try:
            return jsonable_encoder(
                PluginService.upgrade_plugin_with_marketplace(
                    tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
                )
            )
        except PluginDaemonClientSideError as e:
            raise ValueError(e)


class PluginUpgradeFromGithubApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @plugin_permission_required(debug_required=True)
    def post(self):
        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
        parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
        parser.add_argument("repo", type=str, required=True, location="json")
        parser.add_argument("version", type=str, required=True, location="json")
        parser.add_argument("package", type=str, required=True, location="json")
        args = parser.parse_args()

        try:
            return jsonable_encoder(
                PluginService.upgrade_plugin_with_github(
                    tenant_id,
                    args["original_plugin_unique_identifier"],
                    args["new_plugin_unique_identifier"],
                    args["repo"],
                    args["version"],
                    args["package"],
                )
            )
        except PluginDaemonClientSideError as e:
            raise ValueError(e)


class PluginUninstallApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @plugin_permission_required(debug_required=True)
    def post(self):
        req = reqparse.RequestParser()
        req.add_argument("plugin_installation_id", type=str, required=True, location="json")
        args = req.parse_args()

        tenant_id = current_user.current_tenant_id

        try:
            return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
        except PluginDaemonClientSideError as e:
            raise ValueError(e)


class PluginChangePermissionApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        user = current_user
        if not user.is_admin_or_owner:
            raise Forbidden()

        req = reqparse.RequestParser()
        req.add_argument("install_permission", type=str, required=True, location="json")
        req.add_argument("debug_permission", type=str, required=True, location="json")
        args = req.parse_args()

        install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
        debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])

        tenant_id = user.current_tenant_id

        return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}


class PluginFetchPermissionApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        tenant_id = current_user.current_tenant_id

        permission = PluginPermissionService.get_permission(tenant_id)
        if not permission:
            return jsonable_encoder(
                {
                    "install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
                    "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
                }
            )

        return jsonable_encoder(
            {
                "install_permission": permission.install_permission,
                "debug_permission": permission.debug_permission,
            }
        )


api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")

api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")

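For orientation, a minimal sketch of how a workspace member with install permission might drive the marketplace install flow over these console endpoints. This is illustrative only: the base URL, the already-authenticated session, and the identifier value are assumptions, not taken from this diff, and the task payload shape is whatever PluginService.fetch_install_tasks returns.

# Hypothetical console-API client flow: install a marketplace plugin,
# then list the daemon's install tasks.
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
session = requests.Session()  # assumed to already carry a logged-in session

# kick off the install (plugin_unique_identifiers must be a list of strings,
# matching the validation in PluginInstallFromMarketplaceApi)
resp = session.post(
    f"{BASE}/workspaces/current/plugin/install/marketplace",
    json={"plugin_unique_identifiers": ["langgenius/example-plugin:0.0.1@abc123"]},
)
resp.raise_for_status()

# inspect pending work via PluginFetchInstallTasksApi (page/page_size required)
tasks = session.get(
    f"{BASE}/workspaces/current/plugin/tasks", params={"page": 1, "page_size": 10}
).json()["tasks"]
for task in tasks:
    print(task)  # task shape is defined by PluginService.fetch_install_tasks
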
@@ -25,8 +25,10 @@ class ToolProviderListApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         req = reqparse.RequestParser()
         req.add_argument(
@@ -47,28 +49,43 @@ class ToolBuiltinProviderListToolsApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, provider):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        tenant_id = user.current_tenant_id
 
         return jsonable_encoder(
             BuiltinToolManageService.list_builtin_tool_provider_tools(
-                user_id,
                 tenant_id,
                 provider,
             )
         )
 
 
+class ToolBuiltinProviderInfoApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider):
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
+
+        return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
+
+
 class ToolBuiltinProviderDeleteApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
     def post(self, provider):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
             raise Forbidden()
 
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         return BuiltinToolManageService.delete_builtin_tool_provider(
             user_id,
@@ -82,11 +99,13 @@ class ToolBuiltinProviderUpdateApi(Resource):
     @login_required
     @account_initialization_required
     def post(self, provider):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
             raise Forbidden()
 
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -131,11 +150,13 @@ class ToolApiProviderAddApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
             raise Forbidden()
 
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -168,6 +189,11 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
+
         parser = reqparse.RequestParser()
 
         parser.add_argument("url", type=str, required=True, nullable=False, location="args")
@@ -175,8 +201,8 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
         args = parser.parse_args()
 
         return ApiToolManageService.get_api_tool_provider_remote_schema(
-            current_user.id,
-            current_user.current_tenant_id,
+            user_id,
+            tenant_id,
             args["url"],
         )
 
@@ -186,8 +212,10 @@ class ToolApiProviderListToolsApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         parser = reqparse.RequestParser()
 
@@ -209,11 +237,13 @@ class ToolApiProviderUpdateApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
             raise Forbidden()
 
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
@@ -248,11 +278,13 @@ class ToolApiProviderDeleteApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
             raise Forbidden()
 
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         parser = reqparse.RequestParser()
 
@@ -272,8 +304,10 @@ class ToolApiProviderGetApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         parser = reqparse.RequestParser()
 
@@ -293,7 +327,11 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
     @login_required
     @account_initialization_required
     def get(self, provider):
-        return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
+        user = current_user
+
+        tenant_id = user.current_tenant_id
+
+        return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
 
 
 class ToolApiProviderSchemaApi(Resource):
@@ -344,11 +382,13 @@ class ToolWorkflowProviderCreateApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
             raise Forbidden()
 
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         reqparser = reqparse.RequestParser()
         reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -381,11 +421,13 @@ class ToolWorkflowProviderUpdateApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
             raise Forbidden()
 
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         reqparser = reqparse.RequestParser()
         reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -421,11 +463,13 @@ class ToolWorkflowProviderDeleteApi(Resource):
     @login_required
     @account_initialization_required
     def post(self):
-        if not current_user.is_admin_or_owner:
+        user = current_user
+
+        if not user.is_admin_or_owner:
             raise Forbidden()
 
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         reqparser = reqparse.RequestParser()
         reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json")
@@ -444,8 +488,10 @@ class ToolWorkflowProviderGetApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args")
@@ -476,8 +522,10 @@ class ToolWorkflowProviderListToolApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         parser = reqparse.RequestParser()
         parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args")
@@ -498,8 +546,10 @@ class ToolBuiltinListApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         return jsonable_encoder(
             [
@@ -517,8 +567,10 @@ class ToolApiListApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         return jsonable_encoder(
             [
@@ -536,8 +588,10 @@ class ToolWorkflowListApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        user_id = current_user.id
-        tenant_id = current_user.current_tenant_id
+        user = current_user
+
+        user_id = user.id
+        tenant_id = user.current_tenant_id
 
         return jsonable_encoder(
             [
@@ -563,16 +617,18 @@ class ToolLabelsApi(Resource):
 api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
 
 # builtin tool provider
-api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
-api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
-api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
+api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
+api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
+api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
+api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
 api.add_resource(
-    ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
+    ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
 )
 api.add_resource(
-    ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
+    ToolBuiltinProviderCredentialsSchemaApi,
+    "/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
 )
-api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
+api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
 
 # api tool provider
 api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
@@ -7,6 +7,7 @@ from flask_login import current_user  # type: ignore
 
 from configs import dify_config
 from controllers.console.workspace.error import AccountNotInitializedError
+from extensions.ext_database import db
 from models.model import DifySetup
 from services.feature_service import FeatureService, LicenseStatus
 from services.operation_service import OperationService
@@ -134,9 +135,13 @@ def setup_required(view):
     @wraps(view)
     def decorated(*args, **kwargs):
         # check setup
-        if dify_config.EDITION == "SELF_HOSTED" and os.environ.get("INIT_PASSWORD") and not DifySetup.query.first():
+        if (
+            dify_config.EDITION == "SELF_HOSTED"
+            and os.environ.get("INIT_PASSWORD")
+            and not db.session.query(DifySetup).first()
+        ):
             raise NotInitValidateError()
-        elif dify_config.EDITION == "SELF_HOSTED" and not DifySetup.query.first():
+        elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
            raise NotSetupError()
 
         return view(*args, **kwargs)
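
The switch from the legacy `DifySetup.query` class property to an explicit `db.session.query(DifySetup)` keeps this decorator compatible with SQLAlchemy 2.0-style sessions, where the implicit `Model.query` accessor is being phased out. A minimal illustration of the equivalence (a sketch, not part of this diff):

# Both expressions issue SELECT ... LIMIT 1; the second does not rely on the
# Flask-SQLAlchemy query property being attached to the model class.
first_row = DifySetup.query.first()              # legacy style
first_row = db.session.query(DifySetup).first()  # explicit session style
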
@@ -6,4 +6,4 @@ bp = Blueprint("files", __name__)
 api = ExternalApi(bp)
 
 
-from . import image_preview, tool_files
+from . import image_preview, tool_files, upload
api/controllers/files/upload.py (new file, 69 lines)
@@ -0,0 +1,69 @@
from flask import request
from flask_restful import Resource, marshal_with  # type: ignore
from werkzeug.exceptions import Forbidden

import services
from controllers.console.wraps import setup_required
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from controllers.inner_api.plugin.wraps import get_user
from controllers.service_api.app.error import FileTooLargeError
from core.file.helpers import verify_plugin_file_signature
from fields.file_fields import file_fields
from services.file_service import FileService


class PluginUploadFileApi(Resource):
    @setup_required
    @marshal_with(file_fields)
    def post(self):
        # get file from request
        file = request.files["file"]

        timestamp = request.args.get("timestamp")
        nonce = request.args.get("nonce")
        sign = request.args.get("sign")
        tenant_id = request.args.get("tenant_id")
        if not tenant_id:
            raise Forbidden("Invalid request.")

        user_id = request.args.get("user_id")
        user = get_user(tenant_id, user_id)

        filename = file.filename
        mimetype = file.mimetype

        if not filename or not mimetype:
            raise Forbidden("Invalid request.")

        if not timestamp or not nonce or not sign:
            raise Forbidden("Invalid request.")

        if not verify_plugin_file_signature(
            filename=filename,
            mimetype=mimetype,
            tenant_id=tenant_id,
            user_id=user_id,
            timestamp=timestamp,
            nonce=nonce,
            sign=sign,
        ):
            raise Forbidden("Invalid request.")

        try:
            upload_file = FileService.upload_file(
                filename=filename,
                content=file.read(),
                mimetype=mimetype,
                user=user,
                source=None,
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()

        return upload_file, 201


api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")

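A sketch of what a client-side call against this endpoint could look like, assuming the caller has already obtained a pre-signed URL (the timestamp, nonce, sign, tenant_id and user_id values below are placeholders; the real ones are embedded by get_signed_file_url_for_plugin and checked by verify_plugin_file_signature):

# Hypothetical upload against /files/upload/for-plugin using a pre-signed URL.
import requests

signed_url = (
    "http://localhost:5001/files/upload/for-plugin"
    "?timestamp=1700000000&nonce=abc&sign=placeholder"
    "&tenant_id=TENANT_ID&user_id=USER_ID"
)
with open("report.pdf", "rb") as f:
    resp = requests.post(signed_url, files={"file": ("report.pdf", f, "application/pdf")})
resp.raise_for_status()  # 201 with the marshalled file_fields payload on success
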
@@ -5,4 +5,5 @@ from libs.external_api import ExternalApi
 bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
 api = ExternalApi(bp)
 
+from .plugin import plugin
 from .workspace import workspace
api/controllers/inner_api/plugin/plugin.py (new file, 293 lines)
@@ -0,0 +1,293 @@
from flask_restful import Resource  # type: ignore

from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only
from core.file.helpers import get_signed_file_url_for_plugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
from core.plugin.backwards_invocation.encrypt import PluginEncrypter
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import (
    RequestInvokeApp,
    RequestInvokeEncrypt,
    RequestInvokeLLM,
    RequestInvokeModeration,
    RequestInvokeParameterExtractorNode,
    RequestInvokeQuestionClassifierNode,
    RequestInvokeRerank,
    RequestInvokeSpeech2Text,
    RequestInvokeSummary,
    RequestInvokeTextEmbedding,
    RequestInvokeTool,
    RequestInvokeTTS,
    RequestRequestUploadFile,
)
from core.tools.entities.tool_entities import ToolProviderType
from libs.helper import compact_generate_response
from models.account import Account, Tenant
from models.model import EndUser


class PluginInvokeLLMApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeLLM)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
        def generator():
            response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
            return PluginModelBackwardsInvocation.convert_to_event_stream(response)

        return compact_generate_response(generator())


class PluginInvokeTextEmbeddingApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeTextEmbedding)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginModelBackwardsInvocation.invoke_text_embedding(
                        user_id=user_model.id,
                        tenant=tenant_model,
                        payload=payload,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeRerankApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeRerank)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginModelBackwardsInvocation.invoke_rerank(
                        user_id=user_model.id,
                        tenant=tenant_model,
                        payload=payload,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeTTSApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeTTS)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS):
        def generator():
            response = PluginModelBackwardsInvocation.invoke_tts(
                user_id=user_model.id,
                tenant=tenant_model,
                payload=payload,
            )
            return PluginModelBackwardsInvocation.convert_to_event_stream(response)

        return compact_generate_response(generator())


class PluginInvokeSpeech2TextApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeSpeech2Text)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginModelBackwardsInvocation.invoke_speech2text(
                        user_id=user_model.id,
                        tenant=tenant_model,
                        payload=payload,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeModerationApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeModeration)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginModelBackwardsInvocation.invoke_moderation(
                        user_id=user_model.id,
                        tenant=tenant_model,
                        payload=payload,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeToolApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeTool)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
        def generator():
            return PluginToolBackwardsInvocation.convert_to_event_stream(
                PluginToolBackwardsInvocation.invoke_tool(
                    tenant_id=tenant_model.id,
                    user_id=user_model.id,
                    tool_type=ToolProviderType.value_of(payload.tool_type),
                    provider=payload.provider,
                    tool_name=payload.tool,
                    tool_parameters=payload.tool_parameters,
                ),
            )

        return compact_generate_response(generator())


class PluginInvokeParameterExtractorNodeApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeParameterExtractorNode)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginNodeBackwardsInvocation.invoke_parameter_extractor(
                        tenant_id=tenant_model.id,
                        user_id=user_model.id,
                        parameters=payload.parameters,
                        model_config=payload.model,
                        instruction=payload.instruction,
                        query=payload.query,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeQuestionClassifierNodeApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
        try:
            return jsonable_encoder(
                BaseBackwardsInvocationResponse(
                    data=PluginNodeBackwardsInvocation.invoke_question_classifier(
                        tenant_id=tenant_model.id,
                        user_id=user_model.id,
                        query=payload.query,
                        model_config=payload.model,
                        classes=payload.classes,
                        instruction=payload.instruction,
                    )
                )
            )
        except Exception as e:
            return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))


class PluginInvokeAppApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeApp)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
        response = PluginAppBackwardsInvocation.invoke_app(
            app_id=payload.app_id,
            user_id=user_model.id,
            tenant_id=tenant_model.id,
            conversation_id=payload.conversation_id,
            query=payload.query,
            stream=payload.response_mode == "streaming",
            inputs=payload.inputs,
            files=payload.files,
        )

        return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))


class PluginInvokeEncryptApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeEncrypt)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt):
        """
        encrypt or decrypt data
        """
        try:
            return BaseBackwardsInvocationResponse(
                data=PluginEncrypter.invoke_encrypt(tenant_model, payload)
            ).model_dump()
        except Exception as e:
            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()


class PluginInvokeSummaryApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestInvokeSummary)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary):
        try:
            return BaseBackwardsInvocationResponse(
                data={
                    "summary": PluginModelBackwardsInvocation.invoke_summary(
                        user_id=user_model.id,
                        tenant=tenant_model,
                        payload=payload,
                    )
                }
            ).model_dump()
        except Exception as e:
            return BaseBackwardsInvocationResponse(error=str(e)).model_dump()


class PluginUploadFileRequestApi(Resource):
    @setup_required
    @plugin_inner_api_only
    @get_user_tenant
    @plugin_data(payload_type=RequestRequestUploadFile)
    def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
        # generate signed url
        url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
        return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()


api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
api.add_resource(PluginInvokeToolApi, "/invoke/tool")
api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")

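As a concrete example of the request shape these inner endpoints expect: get_user_tenant pulls tenant_id and user_id from the JSON body, plugin_data validates the remaining fields against the pydantic payload type, and plugin_inner_api_only requires the daemon key header. A hedged sketch for the signed-upload-URL endpoint (the header name comes from the decorator; the key value, host, and IDs are placeholders):

# Hypothetical daemon-side call to PluginUploadFileRequestApi.
import requests

resp = requests.post(
    "http://localhost:5001/inner/api/upload/file/request",
    headers={"X-Inner-Api-Key": "PLUGIN_INNER_API_KEY"},  # placeholder key
    json={
        "tenant_id": "TENANT_ID",
        "user_id": "USER_ID",
        # RequestRequestUploadFile fields used by the handler:
        "filename": "report.pdf",
        "mimetype": "application/pdf",
    },
)
print(resp.json()["data"]["url"])  # signed URL for /files/upload/for-plugin
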
api/controllers/inner_api/plugin/wraps.py (new file, 116 lines)
@@ -0,0 +1,116 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional

from flask import request
from flask_restful import reqparse  # type: ignore
from pydantic import BaseModel
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.account import Account, Tenant
from models.model import EndUser
from services.account_service import AccountService


def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
    try:
        with Session(db.engine) as session:
            if not user_id:
                user_id = "DEFAULT-USER"

            if user_id == "DEFAULT-USER":
                user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
                if not user_model:
                    user_model = EndUser(
                        tenant_id=tenant_id,
                        type="service_api",
                        is_anonymous=True if user_id == "DEFAULT-USER" else False,
                        session_id=user_id,
                    )
                    session.add(user_model)
                    session.commit()
            else:
                user_model = AccountService.load_user(user_id)
                if not user_model:
                    user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
                if not user_model:
                    raise ValueError("user not found")
    except Exception:
        raise ValueError("user not found")

    return user_model


def get_user_tenant(view: Optional[Callable] = None):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            # fetch json body
            parser = reqparse.RequestParser()
            parser.add_argument("tenant_id", type=str, required=True, location="json")
            parser.add_argument("user_id", type=str, required=True, location="json")

            kwargs = parser.parse_args()

            user_id = kwargs.get("user_id")
            tenant_id = kwargs.get("tenant_id")

            if not tenant_id:
                raise ValueError("tenant_id is required")

            if not user_id:
                user_id = "DEFAULT-USER"

            del kwargs["tenant_id"]
            del kwargs["user_id"]

            try:
                tenant_model = (
                    db.session.query(Tenant)
                    .filter(
                        Tenant.id == tenant_id,
                    )
                    .first()
                )
            except Exception:
                raise ValueError("tenant not found")

            if not tenant_model:
                raise ValueError("tenant not found")

            kwargs["tenant_model"] = tenant_model
            kwargs["user_model"] = get_user(tenant_id, user_id)

            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)


def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
    def decorator(view_func):
        def decorated_view(*args, **kwargs):
            try:
                data = request.get_json()
            except Exception:
                raise ValueError("invalid json")

            try:
                payload = payload_type(**data)
            except Exception as e:
                raise ValueError(f"invalid payload: {str(e)}")

            kwargs["payload"] = payload
            return view_func(*args, **kwargs)

        return decorated_view

    if view is None:
        return decorator
    else:
        return decorator(view)

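To see how these decorators compose, here is a toy resource (illustrative only; GreetingPayload and GreetingApi are not part of the diff). get_user_tenant resolves and injects user_model/tenant_model from the request body, then plugin_data re-reads the JSON body and injects the validated payload:

# Hypothetical usage of get_user_tenant + plugin_data, within a Flask request context.
from flask_restful import Resource  # type: ignore
from pydantic import BaseModel


class GreetingPayload(BaseModel):  # hypothetical payload type
    greeting: str


class GreetingApi(Resource):
    @get_user_tenant
    @plugin_data(payload_type=GreetingPayload)
    def post(self, user_model, tenant_model, payload: GreetingPayload):
        # user_model/tenant_model are injected by get_user_tenant,
        # payload by plugin_data's pydantic validation
        return {"message": f"{payload.greeting}, tenant {tenant_model.id}"}
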
@@ -1,8 +1,10 @@
+import json
+
 from flask_restful import Resource, reqparse  # type: ignore
 
 from controllers.console.wraps import setup_required
 from controllers.inner_api import api
-from controllers.inner_api.wraps import inner_api_only
+from controllers.inner_api.wraps import enterprise_inner_api_only
 from events.tenant_event import tenant_was_created
 from models.account import Account
 from services.account_service import TenantService
@@ -10,7 +12,7 @@ from services.account_service import TenantService
 
 class EnterpriseWorkspace(Resource):
     @setup_required
-    @inner_api_only
+    @enterprise_inner_api_only
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument("name", type=str, required=True, location="json")
@@ -29,4 +31,34 @@ class EnterpriseWorkspace(Resource):
         return {"message": "enterprise workspace created."}
 
 
+class EnterpriseWorkspaceNoOwnerEmail(Resource):
+    @setup_required
+    @enterprise_inner_api_only
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("name", type=str, required=True, location="json")
+        args = parser.parse_args()
+
+        tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
+
+        tenant_was_created.send(tenant)
+
+        resp = {
+            "id": tenant.id,
+            "name": tenant.name,
+            "encrypt_public_key": tenant.encrypt_public_key,
+            "plan": tenant.plan,
+            "status": tenant.status,
+            "custom_config": json.loads(tenant.custom_config) if tenant.custom_config else {},
+            "created_at": tenant.created_at.isoformat() if tenant.created_at else None,
+            "updated_at": tenant.updated_at.isoformat() if tenant.updated_at else None,
+        }
+
+        return {
+            "message": "enterprise workspace created.",
+            "tenant": resp,
+        }
+
+
 api.add_resource(EnterpriseWorkspace, "/enterprise/workspace")
+api.add_resource(EnterpriseWorkspaceNoOwnerEmail, "/enterprise/workspace/ownerless")
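
A sketch of the call this new endpoint accepts, assuming the enterprise inner API is enabled; the /inner/api prefix comes from the blueprint shown earlier, and the host and key value are placeholders:

# Hypothetical dashboard-side call creating a workspace with no owner yet.
import requests

resp = requests.post(
    "http://localhost:5001/inner/api/enterprise/workspace/ownerless",
    headers={"X-Inner-Api-Key": "ENTERPRISE_INNER_API_KEY"},  # placeholder key
    json={"name": "acme-workspace"},
)
print(resp.json()["tenant"]["id"])  # tenant fields serialized by the handler
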
@@ -10,7 +10,7 @@ from extensions.ext_database import db
 from models.model import EndUser
 
 
-def inner_api_only(view):
+def enterprise_inner_api_only(view):
     @wraps(view)
     def decorated(*args, **kwargs):
         if not dify_config.INNER_API:
@@ -18,7 +18,7 @@ def inner_api_only(view):
 
         # get header 'X-Inner-Api-Key'
         inner_api_key = request.headers.get("X-Inner-Api-Key")
-        if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
+        if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
             abort(401)
 
         return view(*args, **kwargs)
@@ -26,7 +26,7 @@ def inner_api_only(view):
     return decorated
 
 
-def inner_api_user_auth(view):
+def enterprise_inner_api_user_auth(view):
     @wraps(view)
     def decorated(*args, **kwargs):
         if not dify_config.INNER_API:
@@ -60,3 +60,19 @@ def inner_api_user_auth(view):
         return view(*args, **kwargs)
 
     return decorated
+
+
+def plugin_inner_api_only(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        if not dify_config.PLUGIN_DAEMON_KEY:
+            abort(404)
+
+        # get header 'X-Inner-Api-Key'
+        inner_api_key = request.headers.get("X-Inner-Api-Key")
+        if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
+            abort(404)
+
+        return view(*args, **kwargs)
+
+    return decorated
@@ -7,4 +7,4 @@ api = ExternalApi(bp)
 
 from . import index
 from .app import app, audio, completion, conversation, file, message, workflow
-from .dataset import dataset, document, hit_testing, segment
+from .dataset import dataset, document, hit_testing, segment, upload_file
@@ -31,8 +31,11 @@ class DatasetListApi(DatasetApiResource):
         # provider = request.args.get("provider", default="vendor")
         search = request.args.get("keyword", default=None, type=str)
         tag_ids = request.args.getlist("tag_ids")
+        include_all = request.args.get("include_all", default="false").lower() == "true"
 
-        datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
+        datasets, total = DatasetService.get_datasets(
+            page, limit, tenant_id, current_user, search, tag_ids, include_all
+        )
         # check embedding setting
         provider_manager = ProviderManager()
         configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
@@ -18,6 +18,7 @@ from controllers.service_api.app.error import (
 from controllers.service_api.dataset.error import (
     ArchivedDocumentImmutableError,
     DocumentIndexingError,
+    InvalidMetadataError,
 )
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
 from core.errors.error import ProviderTokenNotInitError
@@ -50,6 +51,9 @@ class DocumentAddByTextApi(DatasetApiResource):
             "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
         )
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+        parser.add_argument("doc_type", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("doc_metadata", type=dict, required=False, nullable=True, location="json")
+
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
@@ -61,6 +65,28 @@ class DocumentAddByTextApi(DatasetApiResource):
         if not dataset.indexing_technique and not args["indexing_technique"]:
             raise ValueError("indexing_technique is required.")
 
+        # Validate metadata if provided
+        if args.get("doc_type") or args.get("doc_metadata"):
+            if not args.get("doc_type") or not args.get("doc_metadata"):
+                raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
+
+            if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
+                raise InvalidMetadataError(
+                    "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
+                )
+
+            if not isinstance(args["doc_metadata"], dict):
+                raise InvalidMetadataError("doc_metadata must be a dictionary")
+
+            # Validate metadata schema based on doc_type
+            if args["doc_type"] != "others":
+                metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
+                for key, value in args["doc_metadata"].items():
+                    if key in metadata_schema and not isinstance(value, metadata_schema[key]):
+                        raise InvalidMetadataError(f"Invalid type for metadata field {key}")
+            # set to MetaDataConfig
+            args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
+
         text = args.get("text")
         name = args.get("name")
         if text is None or name is None:
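
An example request body that would pass this validation, with hypothetical field values ("others" is the one doc_type the diff itself guarantees; the typed doc_types live in DocumentService.DOCUMENT_METADATA_SCHEMA and are not listed here):

# Hypothetical create-by-text payload using the new metadata fields.
payload = {
    "name": "q3-report",
    "text": "Quarterly results ...",
    "indexing_technique": "high_quality",  # assumed member of Dataset.INDEXING_TECHNIQUE_LIST
    # doc_type and doc_metadata must be provided together;
    # "others" skips the per-field type checks above
    "doc_type": "others",
    "doc_metadata": {"department": "finance", "quarter": "Q3"},
}
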
@@ -107,6 +133,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+        parser.add_argument("doc_type", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("doc_metadata", type=dict, required=False, nullable=True, location="json")
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
@@ -115,6 +143,32 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         if not dataset:
             raise ValueError("Dataset is not exist.")
 
+        # indexing_technique is already set in dataset since this is an update
+        args["indexing_technique"] = dataset.indexing_technique
+
+        # Validate metadata if provided
+        if args.get("doc_type") or args.get("doc_metadata"):
+            if not args.get("doc_type") or not args.get("doc_metadata"):
+                raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
+
+            if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
+                raise InvalidMetadataError(
+                    "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
+                )
+
+            if not isinstance(args["doc_metadata"], dict):
+                raise InvalidMetadataError("doc_metadata must be a dictionary")
+
+            # Validate metadata schema based on doc_type
+            if args["doc_type"] != "others":
+                metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
+                for key, value in args["doc_metadata"].items():
+                    if key in metadata_schema and not isinstance(value, metadata_schema[key]):
+                        raise InvalidMetadataError(f"Invalid type for metadata field {key}")
+
+            # set to MetaDataConfig
+            args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
+
         if args["text"]:
             text = args.get("text")
             name = args.get("name")
@@ -161,6 +215,30 @@ class DocumentAddByFileApi(DatasetApiResource):
             args["doc_form"] = "text_model"
         if "doc_language" not in args:
             args["doc_language"] = "English"
+
+        # Validate metadata if provided
+        if args.get("doc_type") or args.get("doc_metadata"):
+            if not args.get("doc_type") or not args.get("doc_metadata"):
+                raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
+
+            if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
+                raise InvalidMetadataError(
+                    "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
+                )
+
+            if not isinstance(args["doc_metadata"], dict):
+                raise InvalidMetadataError("doc_metadata must be a dictionary")
+
+            # Validate metadata schema based on doc_type
+            if args["doc_type"] != "others":
+                metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
+                for key, value in args["doc_metadata"].items():
+                    if key in metadata_schema and not isinstance(value, metadata_schema[key]):
+                        raise InvalidMetadataError(f"Invalid type for metadata field {key}")
+
+            # set to MetaDataConfig
+            args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
+
         # get dataset info
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
@@ -228,6 +306,29 @@ class DocumentUpdateByFileApi(DatasetApiResource):
         if "doc_language" not in args:
             args["doc_language"] = "English"
 
+        # Validate metadata if provided
+        if args.get("doc_type") or args.get("doc_metadata"):
+            if not args.get("doc_type") or not args.get("doc_metadata"):
+                raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
+
+            if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
+                raise InvalidMetadataError(
+                    "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
+                )
+
+            if not isinstance(args["doc_metadata"], dict):
+                raise InvalidMetadataError("doc_metadata must be a dictionary")
+
+            # Validate metadata schema based on doc_type
+            if args["doc_type"] != "others":
+                metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
+                for key, value in args["doc_metadata"].items():
+                    if key in metadata_schema and not isinstance(value, metadata_schema[key]):
+                        raise InvalidMetadataError(f"Invalid type for metadata field {key}")
+
+            # set to MetaDataConfig
+            args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
+
         # get dataset info
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
@@ -53,8 +53,7 @@ class SegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -95,8 +94,7 @@ class SegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -175,8 +173,7 @@ class DatasetSegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
            )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
api/controllers/service_api/dataset/upload_file.py (new file, 54 lines)
@@ -0,0 +1,54 @@
from werkzeug.exceptions import NotFound

from controllers.service_api import api
from controllers.service_api.wraps import (
    DatasetApiResource,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService


class UploadFileApi(DatasetApiResource):
    def get(self, tenant_id, dataset_id, document_id):
        """Get upload file."""
        # check dataset
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
        if not dataset:
            raise NotFound("Dataset not found.")
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset.id, document_id)
        if not document:
            raise NotFound("Document not found.")
        # check upload file
        if document.data_source_type != "upload_file":
            raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
        data_source_info = document.data_source_info_dict
        if data_source_info and "upload_file_id" in data_source_info:
            file_id = data_source_info["upload_file_id"]
            upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
            if not upload_file:
                raise NotFound("UploadFile not found.")
        else:
            raise ValueError("Upload file id not found in document data source info.")

        url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
        return {
            "id": upload_file.id,
            "name": upload_file.name,
            "size": upload_file.size,
            "extension": upload_file.extension,
            "url": url,
            "download_url": f"{url}&as_attachment=true",
            "mime_type": upload_file.mime_type,
            "created_by": upload_file.created_by,
            "created_at": upload_file.created_at.timestamp(),
        }, 200


api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")

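A sketch of fetching a document's source-file metadata via this new service-API endpoint, assuming the usual dataset API token auth (the /v1 prefix and Bearer scheme come from DatasetApiResource conventions, not from this diff; IDs and key are placeholders):

# Hypothetical service-API call to the new upload-file endpoint.
import requests

resp = requests.get(
    "http://localhost:5001/v1/datasets/DATASET_ID/documents/DOCUMENT_ID/upload-file",
    headers={"Authorization": "Bearer DATASET_API_KEY"},
)
info = resp.json()
print(info["name"], info["size"], info["download_url"])  # fields built by UploadFileApi.get
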
@@ -195,7 +195,11 @@ def validate_and_get_api_token(scope: str | None = None):
     with Session(db.engine, expire_on_commit=False) as session:
         update_stmt = (
             update(ApiToken)
-            .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
+            .where(
+                ApiToken.token == auth_token,
+                (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
+                ApiToken.type == scope,
+            )
             .values(last_used_at=current_time)
             .returning(ApiToken)
         )
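
The extra `is_(None)` branch matters for tokens that have never been used: with the old predicate, `last_used_at < cutoff_time` evaluates to SQL NULL for a fresh token, so the UPDATE matched nothing and `last_used_at` was never stamped. A minimal illustration of the NULL-safe predicate on its own (a sketch using `or_`, equivalent to the `|` operator in the fixed `.where()` clause):

# NULL-safe "stale or never used" condition.
from sqlalchemy import or_

stale_or_unused = or_(
    ApiToken.last_used_at.is_(None),      # never used: the old comparison yielded NULL here
    ApiToken.last_used_at < cutoff_time,  # used, but not recently
)
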
@@ -236,7 +240,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
             type="service_api",
-            is_anonymous=True if user_id == "DEFAULT-USER" else False,
+            is_anonymous=user_id == "DEFAULT-USER",
             session_id=user_id,
         )
         db.session.add(end_user)
@@ -39,7 +39,7 @@ class ConversationListApi(WebApiResource):
 
         pinned = None
         if "pinned" in args and args["pinned"] is not None:
-            pinned = True if args["pinned"] == "true" else False
+            pinned = args["pinned"] == "true"
 
         try:
             with Session(db.engine) as session:
@@ -91,7 +91,7 @@ class MessageListApi(WebApiResource):
 
         try:
             return MessageService.pagination_by_first_id(
-                app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
+                app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
             )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
@@ -1,7 +1,6 @@
 import json
 import logging
 import uuid
 from datetime import UTC, datetime
 from typing import Optional, Union, cast
 
 from core.agent.entities import AgentEntity, AgentToolEntity
@@ -32,19 +31,16 @@ from core.model_runtime.entities import (
 from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.utils.extract_thread_messages import extract_thread_messages
+from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import (
     ToolParameter,
-    ToolRuntimeVariablePool,
 )
-from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
-from core.tools.tool.tool import Tool
 from core.tools.tool_manager import ToolManager
+from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
 from extensions.ext_database import db
 from factories import file_factory
 from models.model import Conversation, Message, MessageAgentThought, MessageFile
-from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
 
@@ -62,11 +58,9 @@ class BaseAgentRunner(AppRunner):
         queue_manager: AppQueueManager,
         message: Message,
         user_id: str,
+        model_instance: ModelInstance,
         memory: Optional[TokenBufferMemory] = None,
         prompt_messages: Optional[list[PromptMessage]] = None,
-        variables_pool: Optional[ToolRuntimeVariablePool] = None,
-        db_variables: Optional[ToolConversationVariables] = None,
-        model_instance: ModelInstance,
     ) -> None:
         self.tenant_id = tenant_id
         self.application_generate_entity = application_generate_entity
@@ -79,8 +73,6 @@ class BaseAgentRunner(AppRunner):
         self.user_id = user_id
         self.memory = memory
         self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
-        self.variables_pool = variables_pool
-        self.db_variables_pool = db_variables
         self.model_instance = model_instance
 
         # init callback
@@ -141,11 +133,10 @@ class BaseAgentRunner(AppRunner):
                 agent_tool=tool,
                 invoke_from=self.application_generate_entity.invoke_from,
             )
-            tool_entity.load_variables(self.variables_pool)
-
+            assert tool_entity.entity.description
             message_tool = PromptMessageTool(
                 name=tool.tool_name,
-                description=tool_entity.description.llm if tool_entity.description else "",
+                description=tool_entity.entity.description.llm,
                 parameters={
                     "type": "object",
                     "properties": {},
@@ -153,7 +144,7 @@ class BaseAgentRunner(AppRunner):
                 },
             )
 
-        parameters = tool_entity.get_all_runtime_parameters()
+        parameters = tool_entity.get_merged_runtime_parameters()
         for parameter in parameters:
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
                 continue
@@ -186,9 +177,11 @@ class BaseAgentRunner(AppRunner):
         """
        convert dataset retriever tool to prompt message tool
        """
+        assert tool.entity.description
+
         prompt_tool = PromptMessageTool(
-            name=tool.identity.name if tool.identity else "unknown",
-            description=tool.description.llm if tool.description else "",
+            name=tool.entity.identity.name,
+            description=tool.entity.description.llm,
             parameters={
                 "type": "object",
                 "properties": {},
@@ -234,8 +227,7 @@ class BaseAgentRunner(AppRunner):
             # save prompt tool
             prompt_messages_tools.append(prompt_tool)
             # save tool entity
-            if dataset_tool.identity is not None:
-                tool_instances[dataset_tool.identity.name] = dataset_tool
+            tool_instances[dataset_tool.entity.identity.name] = dataset_tool
 
         return tool_instances, prompt_messages_tools
 
@@ -320,24 +312,24 @@ class BaseAgentRunner(AppRunner):
     def save_agent_thought(
         self,
         agent_thought: MessageAgentThought,
-        tool_name: str,
-        tool_input: Union[str, dict],
-        thought: str,
+        tool_name: str | None,
+        tool_input: Union[str, dict, None],
+        thought: str | None,
         observation: Union[str, dict, None],
         tool_invoke_meta: Union[str, dict, None],
-        answer: str,
+        answer: str | None,
         messages_ids: list[str],
         llm_usage: LLMUsage | None = None,
     ):
         """
         Save agent thought
         """
-        queried_thought = (
+        updated_agent_thought = (
             db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
         )
-        if not queried_thought:
-            raise ValueError(f"Agent thought {agent_thought.id} not found")
-        agent_thought = queried_thought
+        if not updated_agent_thought:
+            raise ValueError("agent thought not found")
+        agent_thought = updated_agent_thought
 
         if thought:
             agent_thought.thought = thought
@@ -349,39 +341,39 @@ class BaseAgentRunner(AppRunner):
             if isinstance(tool_input, dict):
                 try:
                     tool_input = json.dumps(tool_input, ensure_ascii=False)
-                except Exception as e:
+                except Exception:
                     tool_input = json.dumps(tool_input)
 
-            agent_thought.tool_input = tool_input
+            updated_agent_thought.tool_input = tool_input
 
         if observation:
             if isinstance(observation, dict):
                 try:
                     observation = json.dumps(observation, ensure_ascii=False)
-                except Exception as e:
+                except Exception:
                     observation = json.dumps(observation)
 
-            agent_thought.observation = observation
+            updated_agent_thought.observation = observation
 
         if answer:
             agent_thought.answer = answer
 
         if messages_ids is not None and len(messages_ids) > 0:
-            agent_thought.message_files = json.dumps(messages_ids)
+            updated_agent_thought.message_files = json.dumps(messages_ids)
 
         if llm_usage:
-            agent_thought.message_token = llm_usage.prompt_tokens
-            agent_thought.message_price_unit = llm_usage.prompt_price_unit
-            agent_thought.message_unit_price = llm_usage.prompt_unit_price
-            agent_thought.answer_token = llm_usage.completion_tokens
-            agent_thought.answer_price_unit = llm_usage.completion_price_unit
-            agent_thought.answer_unit_price = llm_usage.completion_unit_price
-            agent_thought.tokens = llm_usage.total_tokens
-            agent_thought.total_price = llm_usage.total_price
+            updated_agent_thought.message_token = llm_usage.prompt_tokens
+            updated_agent_thought.message_price_unit = llm_usage.prompt_price_unit
+            updated_agent_thought.message_unit_price = llm_usage.prompt_unit_price
+            updated_agent_thought.answer_token = llm_usage.completion_tokens
+            updated_agent_thought.answer_price_unit = llm_usage.completion_price_unit
+            updated_agent_thought.answer_unit_price = llm_usage.completion_unit_price
+            updated_agent_thought.tokens = llm_usage.total_tokens
+            updated_agent_thought.total_price = llm_usage.total_price
 
         # check if tool labels is not empty
-        labels = agent_thought.tool_labels or {}
-        tools = agent_thought.tool.split(";") if agent_thought.tool else []
+        labels = updated_agent_thought.tool_labels or {}
+        tools = updated_agent_thought.tool.split(";") if updated_agent_thought.tool else []
         for tool in tools:
             if not tool:
                 continue
@@ -392,42 +384,20 @@ class BaseAgentRunner(AppRunner):
             else:
                 labels[tool] = {"en_US": tool, "zh_Hans": tool}
 
-        agent_thought.tool_labels_str = json.dumps(labels)
+        updated_agent_thought.tool_labels_str = json.dumps(labels)
 
         if tool_invoke_meta is not None:
             if isinstance(tool_invoke_meta, dict):
                 try:
                     tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False)
-                except Exception as e:
+                except Exception:
                     tool_invoke_meta = json.dumps(tool_invoke_meta)
 
-            agent_thought.tool_meta_str = tool_invoke_meta
+            updated_agent_thought.tool_meta_str = tool_invoke_meta
 
         db.session.commit()
         db.session.close()
 
-    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
-        """
-        convert tool variables to db variables
-        """
-        queried_variables = (
-            db.session.query(ToolConversationVariables)
-            .filter(
-                ToolConversationVariables.conversation_id == self.message.conversation_id,
-            )
-            .first()
-        )
-
-        if not queried_variables:
-            return
-
-        db_variables = queried_variables
-
-        db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
-        db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
-        db.session.commit()
-        db.session.close()
-
     def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         Organize agent history
@@ -464,11 +434,11 @@ class BaseAgentRunner(AppRunner):
                 tool_call_response: list[ToolPromptMessage] = []
                 try:
                     tool_inputs = json.loads(agent_thought.tool_input)
-                except Exception as e:
+                except Exception:
                     tool_inputs = {tool: {} for tool in tools}
                 try:
                     tool_responses = json.loads(agent_thought.observation)
-                except Exception as e:
+                except Exception:
                     tool_responses = dict.fromkeys(tools, agent_thought.observation)
 
                 for tool in tools:
@@ -515,7 +485,11 @@ class BaseAgentRunner(AppRunner):
         files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
         if not files:
             return UserPromptMessage(content=message.query)
-        file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+        if message.app_model_config:
+            file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
+        else:
+            file_extra_config = None
+
         if not file_extra_config:
             return UserPromptMessage(content=message.query)
 
@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
 from typing import Any, Optional
 
 from core.agent.base_agent_runner import BaseAgentRunner
@@ -18,8 +18,8 @@ from core.model_runtime.entities.message_entities import (
 )
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
+from core.tools.__base.tool import Tool
 from core.tools.entities.tool_entities import ToolInvokeMeta
-from core.tools.tool.tool import Tool
 from core.tools.tool_engine import ToolEngine
 from models.model import Message
 
@@ -27,11 +27,11 @@ from models.model import Message
 class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
     _ignore_observation_providers = ["wenxin"]
-    _historic_prompt_messages: list[PromptMessage] | None = None
-    _agent_scratchpad: list[AgentScratchpadUnit] | None = None
-    _instruction: str = ""  # FIXME this must be str for now
-    _query: str | None = None
-    _prompt_messages_tools: list[PromptMessageTool] = []
+    _historic_prompt_messages: list[PromptMessage]
+    _agent_scratchpad: list[AgentScratchpadUnit]
+    _instruction: str
+    _query: str
+    _prompt_messages_tools: Sequence[PromptMessageTool]
 
     def run(
         self,
@@ -42,6 +42,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         Run Cot agent application
         """
+
         app_generate_entity = self.application_generate_entity
         self._repack_app_generate_entity(app_generate_entity)
         self._init_react_state(query)
@@ -54,17 +55,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             app_generate_entity.model_conf.stop.append("Observation")
 
         app_config = self.app_config
+        assert app_config.agent
 
         # init instruction
         inputs = inputs or {}
-        instruction = app_config.prompt_template.simple_prompt_template
-        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction=instruction or "", inputs=inputs)
+        instruction = app_config.prompt_template.simple_prompt_template or ""
+        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
 
         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1
 
         # convert tools into ModelRuntime Tool format
-        tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
+        tool_instances, prompt_messages_tools = self._init_prompt_tools()
+        self._prompt_messages_tools = prompt_messages_tools
 
         function_call_state = True
         llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
@@ -116,14 +119,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 callbacks=[],
             )
 
-            if not isinstance(chunks, Generator):
-                raise ValueError("Expected streaming response from LLM")
-
-            # check llm result
-            if not chunks:
-                raise ValueError("failed to invoke llm")
-
-            usage_dict: dict[str, Optional[LLMUsage]] = {"usage": None}
+            usage_dict: dict[str, Optional[LLMUsage]] = {}
             react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
             scratchpad = AgentScratchpadUnit(
                 agent_response="",
@@ -143,25 +139,25 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 if isinstance(chunk, AgentScratchpadUnit.Action):
                     action = chunk
                    # detect action
-                    if scratchpad.agent_response is not None:
-                        scratchpad.agent_response += json.dumps(chunk.model_dump())
+                    assert scratchpad.agent_response is not None
+                    scratchpad.agent_response += json.dumps(chunk.model_dump())
                     scratchpad.action_str = json.dumps(chunk.model_dump())
                     scratchpad.action = action
                 else:
-                    if scratchpad.agent_response is not None:
-                        scratchpad.agent_response += chunk
-                    if scratchpad.thought is not None:
-                        scratchpad.thought += chunk
+                    assert scratchpad.agent_response is not None
+                    scratchpad.agent_response += chunk
+                    assert scratchpad.thought is not None
+                    scratchpad.thought += chunk
                     yield LLMResultChunk(
                         model=self.model_config.model,
                         prompt_messages=prompt_messages,
                         system_fingerprint="",
                         delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                     )
-            if scratchpad.thought is not None:
-                scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
-            if self._agent_scratchpad is not None:
-                self._agent_scratchpad.append(scratchpad)
+
+            assert scratchpad.thought is not None
+            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
+            self._agent_scratchpad.append(scratchpad)
 
             # get llm usage
             if "usage" in usage_dict:
@@ -172,7 +168,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
 
             self.save_agent_thought(
                 agent_thought=agent_thought,
-                tool_name=scratchpad.action.action_name if scratchpad.action else "",
+                tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
                 tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
                 thought=scratchpad.thought or "",
@@ -256,8 +252,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             answer=final_answer,
             messages_ids=[],
         )
-        if self.variables_pool is not None and self.db_variables_pool is not None:
-            self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(
             QueueMessageEndEvent(
@@ -275,7 +269,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
     def _handle_invoke_action(
         self,
         action: AgentScratchpadUnit.Action,
-        tool_instances: dict[str, Tool],
+        tool_instances: Mapping[str, Tool],
         message_file_ids: list[str],
         trace_manager: Optional[TraceQueueManager] = None,
     ) -> tuple[str, ToolInvokeMeta]:
@@ -315,11 +309,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             )
 
             # publish files
-            for message_file_id, save_as in message_files:
-                if save_as is not None and self.variables_pool:
-                    # FIXME the save_as type is confusing, it should be a string or not
-                    self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=str(save_as))
-
+            for message_file_id in message_files:
                 # publish message file
                 self.queue_manager.publish(
                     QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
@@ -342,7 +332,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         for key, value in inputs.items():
             try:
                 instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
-            except Exception as e:
+            except Exception:
                 continue
 
         return instruction
@@ -379,7 +369,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         return message
 
     def _organize_historic_prompt_messages(
-        self, current_session_messages: Optional[list[PromptMessage]] = None
+        self, current_session_messages: list[PromptMessage] | None = None
     ) -> list[PromptMessage]:
         """
         organize historic prompt messages
@@ -391,8 +381,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         for message in self.history_prompt_messages:
             if isinstance(message, AssistantPromptMessage):
                 if not current_scratchpad:
-                    if not isinstance(message.content, str | None):
-                        raise NotImplementedError("expected str type")
+                    assert isinstance(message.content, str)
                     current_scratchpad = AgentScratchpadUnit(
                         agent_response=message.content,
                         thought=message.content or "I am thinking about how to help you",
@@ -411,9 +400,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     except:
                         pass
             elif isinstance(message, ToolPromptMessage):
-                if not current_scratchpad:
-                    continue
-                if isinstance(message.content, str):
+                if current_scratchpad:
+                    assert isinstance(message.content, str)
                     current_scratchpad.observation = message.content
                 else:
                     raise NotImplementedError("expected str type")
 
@@ -19,8 +19,8 @@ class CotChatAgentRunner(CotAgentRunner):
         """
         Organize system prompt
         """
-        if not self.app_config.agent:
-            raise ValueError("Agent configuration is not set")
+        assert self.app_config.agent
+        assert self.app_config.agent.prompt
 
         prompt_entity = self.app_config.agent.prompt
         if not prompt_entity:
@@ -83,8 +83,10 @@ class CotChatAgentRunner(CotAgentRunner):
             assistant_message.content = ""  # FIXME: type check tell mypy that assistant_message.content is str
             for unit in agent_scratchpad:
                 if unit.is_final():
+                    assert isinstance(assistant_message.content, str)
                     assistant_message.content += f"Final Answer: {unit.agent_response}"
                 else:
+                    assert isinstance(assistant_message.content, str)
                     assistant_message.content += f"Thought: {unit.thought}\n\n"
                     if unit.action_str:
                         assistant_message.content += f"Action: {unit.action_str}\n\n"
 
@@ -1,18 +1,21 @@
-from enum import Enum
-from typing import Any, Literal, Optional, Union
+from enum import StrEnum
+from typing import Any, Optional, Union
 
 from pydantic import BaseModel
 
 from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
 
 
 class AgentToolEntity(BaseModel):
     """
     Agent Tool Entity.
     """
 
-    provider_type: Literal["builtin", "api", "workflow"]
+    provider_type: ToolProviderType
     provider_id: str
     tool_name: str
     tool_parameters: dict[str, Any] = {}
+    plugin_unique_identifier: str | None = None
 
 
 class AgentPromptEntity(BaseModel):
@@ -66,7 +69,7 @@ class AgentEntity(BaseModel):
     Agent Entity.
     """
 
-    class Strategy(Enum):
+    class Strategy(StrEnum):
         """
         Agent Strategy.
         """
@@ -78,5 +81,13 @@ class AgentEntity(BaseModel):
     model: str
     strategy: Strategy
     prompt: Optional[AgentPromptEntity] = None
-    tools: list[AgentToolEntity] | None = None
+    tools: Optional[list[AgentToolEntity]] = None
     max_iteration: int = 5
+
+
+class AgentInvokeMessage(ToolInvokeMessage):
+    """
+    Agent Invoke Message.
+    """
+
+    pass
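
Aside on the Enum -> StrEnum change above: StrEnum (Python 3.11+) members are themselves str instances, so they compare and serialize as plain strings without reaching for `.value`. A tiny sketch with a hypothetical member value, not taken from this diff:

from enum import Enum, StrEnum


class OldStyle(Enum):
    FUNCTION_CALLING = "function-calling"


class NewStyle(StrEnum):
    FUNCTION_CALLING = "function-calling"


assert NewStyle.FUNCTION_CALLING == "function-calling"  # a StrEnum member is a str
assert OldStyle.FUNCTION_CALLING != "function-calling"  # a plain Enum never equals its raw value
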
@@ -46,18 +46,20 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         # convert tools into ModelRuntime Tool format
         tool_instances, prompt_messages_tools = self._init_prompt_tools()
 
+        assert app_config.agent
+
         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
 
         # continue to run until there is not any tool call
         function_call_state = True
-        llm_usage: dict[str, LLMUsage] = {"usage": LLMUsage.empty_usage()}
+        llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
         final_answer = ""
 
         # get tracing instance
         trace_manager = app_generate_entity.trace_manager
 
-        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
+        def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
             if not final_llm_usage_dict["usage"]:
                 final_llm_usage_dict["usage"] = usage
             else:
@@ -107,7 +109,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
             current_llm_usage = None
 
-            if self.stream_tool_call and isinstance(chunks, Generator):
+            if isinstance(chunks, Generator):
                 is_first_chunk = True
                 for chunk in chunks:
                     if is_first_chunk:
@@ -124,7 +126,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                             tool_call_inputs = json.dumps(
                                 {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                             )
-                        except json.JSONDecodeError as e:
+                        except json.JSONDecodeError:
                             # ensure ascii to avoid encoding error
                             tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
 
@@ -140,7 +142,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         current_llm_usage = chunk.delta.usage
 
                     yield chunk
-            elif not self.stream_tool_call and isinstance(chunks, LLMResult):
+            else:
                 result = chunks
                 # check if there is any tool call
                 if self.check_blocking_tool_calls(result):
@@ -151,7 +153,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         tool_call_inputs = json.dumps(
                             {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                         )
-                    except json.JSONDecodeError as e:
+                    except json.JSONDecodeError:
                         # ensure ascii to avoid encoding error
                         tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
 
@@ -183,8 +185,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         usage=result.usage,
                     ),
                 )
-            else:
-                raise RuntimeError(f"invalid chunks type: {type(chunks)}")
 
             assistant_message = AssistantPromptMessage(content="", tool_calls=[])
             if tool_calls:
@@ -243,15 +243,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 invoke_from=self.application_generate_entity.invoke_from,
                 agent_tool_callback=self.agent_callback,
                 trace_manager=trace_manager,
+                app_id=self.application_generate_entity.app_config.app_id,
+                message_id=self.message.id,
+                conversation_id=self.conversation.id,
             )
             # publish files
-            for message_file_id, save_as in message_files:
-                if save_as:
-                    if self.variables_pool:
-                        self.variables_pool.set_file(
-                            tool_name=tool_call_name, value=message_file_id, name=save_as
-                        )
-
+            for message_file_id in message_files:
                 # publish message file
                 self.queue_manager.publish(
                     QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
@@ -303,8 +300,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
             iteration_step += 1
 
-        if self.variables_pool and self.db_variables_pool:
-            self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(
             QueueMessageEndEvent(
@@ -335,9 +330,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 return True
         return False
 
-    def extract_tool_calls(
-        self, llm_result_chunk: LLMResultChunk
-    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
+    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
         """
         Extract tool calls from llm result chunk
 
@@ -360,7 +353,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         return tool_calls
 
-    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
+    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
         """
         Extract blocking tool calls from llm result
 
@@ -383,9 +376,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         return tool_calls
 
-    def _init_system_message(
-        self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
-    ) -> list[PromptMessage]:
+    def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         Initialize system message
         """
api/core/agent/plugin_entities.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import enum
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator

from core.entities.parameter_entities import CommonParameterType
from core.plugin.entities.parameters import (
    PluginParameter,
    as_normal_type,
    cast_parameter_value,
    init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ToolIdentity,
    ToolProviderIdentity,
)


class AgentStrategyProviderIdentity(ToolProviderIdentity):
    """
    Inherits from ToolProviderIdentity, without any additional fields.
    """

    pass


class AgentStrategyParameter(PluginParameter):
    class AgentStrategyParameterType(enum.StrEnum):
        """
        Keep all the types from PluginParameterType
        """

        STRING = CommonParameterType.STRING.value
        NUMBER = CommonParameterType.NUMBER.value
        BOOLEAN = CommonParameterType.BOOLEAN.value
        SELECT = CommonParameterType.SELECT.value
        SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
        FILE = CommonParameterType.FILE.value
        FILES = CommonParameterType.FILES.value
        APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
        MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
        TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value

        # deprecated, should not use.
        SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value

        def as_normal_type(self):
            return as_normal_type(self)

        def cast_value(self, value: Any):
            return cast_parameter_value(self, value)

    type: AgentStrategyParameterType = Field(..., description="The type of the parameter")

    def init_frontend_parameter(self, value: Any):
        return init_frontend_parameter(self, self.type, value)


class AgentStrategyProviderEntity(BaseModel):
    identity: AgentStrategyProviderIdentity
    plugin_id: Optional[str] = Field(None, description="The id of the plugin")


class AgentStrategyIdentity(ToolIdentity):
    """
    Inherits from ToolIdentity, without any additional fields.
    """

    pass


class AgentStrategyEntity(BaseModel):
    identity: AgentStrategyIdentity
    parameters: list[AgentStrategyParameter] = Field(default_factory=list)
    description: I18nObject = Field(..., description="The description of the agent strategy")
    output_schema: Optional[dict] = None

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    @field_validator("parameters", mode="before")
    @classmethod
    def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
        return v or []


class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
    strategies: list[AgentStrategyEntity] = Field(default_factory=list)
api/core/agent/strategy/base.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from collections.abc import Generator, Sequence
from typing import Any, Optional

from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter


class BaseAgentStrategy(ABC):
    """
    Agent Strategy
    """

    def invoke(
        self,
        params: dict[str, Any],
        user_id: str,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[AgentInvokeMessage, None, None]:
        """
        Invoke the agent strategy.
        """
        yield from self._invoke(params, user_id, conversation_id, app_id, message_id)

    def get_parameters(self) -> Sequence[AgentStrategyParameter]:
        """
        Get the parameters for the agent strategy.
        """
        return []

    @abstractmethod
    def _invoke(
        self,
        params: dict[str, Any],
        user_id: str,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[AgentInvokeMessage, None, None]:
        pass
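
Aside: the ABC above leaves exactly one hole to fill. A hypothetical no-op subclass (name and behavior invented for illustration) shows the minimal contract; `invoke` and `get_parameters` are inherited as-is:

from collections.abc import Generator
from typing import Any, Optional

from core.agent.entities import AgentInvokeMessage
from core.agent.strategy.base import BaseAgentStrategy


class NoOpAgentStrategy(BaseAgentStrategy):
    """Yields nothing; exists only to show the subclass surface."""

    def _invoke(
        self,
        params: dict[str, Any],
        user_id: str,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[AgentInvokeMessage, None, None]:
        # A real strategy would yield AgentInvokeMessage chunks here.
        yield from ()


# Usage: list(NoOpAgentStrategy().invoke({}, user_id="user-1")) == []
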
api/core/agent/strategy/plugin.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from collections.abc import Generator, Sequence
from typing import Any, Optional

from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.manager.agent import PluginAgentManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format


class PluginAgentStrategy(BaseAgentStrategy):
    """
    Agent Strategy
    """

    tenant_id: str
    declaration: AgentStrategyEntity

    def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
        self.tenant_id = tenant_id
        self.declaration = declaration

    def get_parameters(self) -> Sequence[AgentStrategyParameter]:
        return self.declaration.parameters

    def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]:
        """
        Initialize the parameters for the agent strategy.
        """
        for parameter in self.declaration.parameters:
            params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
        return params

    def _invoke(
        self,
        params: dict[str, Any],
        user_id: str,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[AgentInvokeMessage, None, None]:
        """
        Invoke the agent strategy.
        """
        manager = PluginAgentManager()

        initialized_params = self.initialize_parameters(params)
        params = convert_parameters_to_plugin_format(initialized_params)

        yield from manager.invoke(
            tenant_id=self.tenant_id,
            user_id=user_id,
            agent_provider=self.declaration.identity.provider,
            agent_strategy=self.declaration.identity.name,
            agent_params=params,
            conversation_id=conversation_id,
            app_id=app_id,
            message_id=message_id,
        )
@@ -4,7 +4,8 @@ from core.app.app_config.entities import EasyUIBasedAppConfig
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.llm_entities import LLMMode
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.provider_manager import ProviderManager
 
@@ -63,14 +64,14 @@ class ModelConfigConverter:
             stop = completion_params["stop"]
             del completion_params["stop"]
 
+        model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
+
         # get model mode
         model_mode = model_config.mode
         if not model_mode:
-            mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
-
-            model_mode = mode_enum.value
-
-        model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
+            model_mode = LLMMode.CHAT.value
+            if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
+                model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
 
         if not model_schema:
             raise ValueError(f"Model {model_name} not exist.")
@@ -2,8 +2,9 @@ from collections.abc import Mapping
 from typing import Any
 
 from core.app.app_config.entities import ModelConfigEntity
+from core.entities import DEFAULT_PLUGIN_ID
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from core.model_runtime.model_providers import model_provider_factory
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
 from core.provider_manager import ProviderManager
 
 
@@ -53,9 +54,18 @@ class ModelConfigManager:
             raise ValueError("model must be of object type")
 
         # model.provider
+        model_provider_factory = ModelProviderFactory(tenant_id)
         provider_entities = model_provider_factory.get_providers()
         model_provider_names = [provider.provider for provider in provider_entities]
-        if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
+        if "provider" not in config["model"]:
             raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
 
+        if "/" not in config["model"]["provider"]:
+            config["model"]["provider"] = (
+                f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
+            )
+
+        if config["model"]["provider"] not in model_provider_names:
+            raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
+
         # model.name
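
Aside: the normalization above expands a bare provider name into the three-part plugin form and leaves qualified names alone. A standalone sketch; the concrete DEFAULT_PLUGIN_ID value is an assumption, not taken from this diff:

def normalize_provider(provider: str, default_plugin_id: str = "langgenius") -> str:
    # "openai" -> "langgenius/openai/openai"; names already containing "/" pass through
    if "/" not in provider:
        return f"{default_plugin_id}/{provider}/{provider}"
    return provider


assert normalize_provider("openai") == "langgenius/openai/openai"
assert normalize_provider("langgenius/openai/openai") == "langgenius/openai/openai"
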
@@ -37,17 +37,6 @@ logger = logging.getLogger(__name__)
 class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     _dialogue_count: int
 
-    @overload
-    def generate(
-        self,
-        app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        args: Mapping[str, Any],
-        invoke_from: InvokeFrom,
-        streaming: Literal[True],
-    ) -> Generator[str, None, None]: ...
-
     @overload
     def generate(
         self,
@@ -65,20 +54,31 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
-        args: Mapping[str, Any],
+        args: Mapping,
         invoke_from: InvokeFrom,
-        streaming: bool = True,
-    ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
+        streaming: Literal[True],
+    ) -> Generator[Mapping | str, None, None]: ...
+
+    @overload
+    def generate(
+        self,
+        app_model: App,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        args: Mapping,
+        invoke_from: InvokeFrom,
+        streaming: bool,
+    ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
 
     def generate(
         self,
         app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
-        args: Mapping[str, Any],
+        args: Mapping,
         invoke_from: InvokeFrom,
         streaming: bool = True,
-    ):
+    ) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
         """
         Generate App response.
 
@@ -156,6 +156,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             workflow_run_id=workflow_run_id,
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
+        contexts.plugin_tool_providers.set({})
+        contexts.plugin_tool_providers_lock.set(threading.Lock())
 
         return self._generate(
             workflow=workflow,
@@ -167,8 +169,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         )
 
     def single_iteration_generate(
-        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True
-    ) -> Mapping[str, Any] | Generator[str, None, None]:
+        self,
+        app_model: App,
+        workflow: Workflow,
+        node_id: str,
+        user: Account | EndUser,
+        args: Mapping,
+        streaming: bool = True,
+    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
         """
         Generate App response.
 
@@ -205,6 +213,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             ),
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
+        contexts.plugin_tool_providers.set({})
+        contexts.plugin_tool_providers_lock.set(threading.Lock())
 
         return self._generate(
             workflow=workflow,
@@ -224,7 +234,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         application_generate_entity: AdvancedChatAppGenerateEntity,
         conversation: Optional[Conversation] = None,
         stream: bool = True,
-    ) -> Mapping[str, Any] | Generator[str, None, None]:
+    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
         """
         Generate App response.
 
@@ -56,7 +56,7 @@ def _process_future(
 
 
 class AppGeneratorTTSPublisher:
-    def __init__(self, tenant_id: str, voice: str):
+    def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None):
         self.logger = logging.getLogger(__name__)
         self.tenant_id = tenant_id
         self.msg_text = ""
@@ -67,7 +67,7 @@ class AppGeneratorTTSPublisher:
         self.model_instance = self.model_manager.get_default_model_instance(
             tenant_id=self.tenant_id, model_type=ModelType.TTS
         )
-        self.voices = self.model_instance.get_tts_voices()
+        self.voices = self.model_instance.get_tts_voices(language=language)
         values = [voice.get("value") for voice in self.voices]
         self.voice = voice
         if not voice or voice not in values:
@@ -77,7 +77,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                 workflow=workflow,
                 node_id=self.application_generate_entity.single_iteration_run.node_id,
-                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
+                user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
             )
         else:
             inputs = self.application_generate_entity.inputs
@@ -1,4 +1,3 @@
-import json
 from collections.abc import Generator
 from typing import Any, cast
 
@@ -58,7 +57,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
     @classmethod
     def convert_stream_full_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, Any, None]:
+    ) -> Generator[dict | str, Any, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -84,12 +83,12 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 response_chunk.update(data)
             else:
                 response_chunk.update(sub_stream_response.to_dict())
-            yield json.dumps(response_chunk)
+            yield response_chunk
 
     @classmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, Any, None]:
+    ) -> Generator[dict | str, Any, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -123,4 +122,4 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             else:
                 response_chunk.update(sub_stream_response.to_dict())
 
-            yield json.dumps(response_chunk)
+            yield response_chunk
@@ -17,6 +17,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.app.entities.queue_entities import (
     QueueAdvancedChatMessageEndEvent,
+    QueueAgentLogEvent,
     QueueAnnotationReplyEvent,
     QueueErrorEvent,
     QueueIterationCompletedEvent,
@@ -219,7 +220,9 @@ class AdvancedChatAppGenerateTaskPipeline:
             and features_dict["text_to_speech"].get("enabled")
             and features_dict["text_to_speech"].get("autoPlay") == "enabled"
         ):
-            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
+            tts_publisher = AppGeneratorTTSPublisher(
+                tenant_id, features_dict["text_to_speech"].get("voice"), features_dict["text_to_speech"].get("language")
+            )
 
         for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
             while True:
@@ -247,7 +250,7 @@ class AdvancedChatAppGenerateTaskPipeline:
                 else:
                     start_listener_time = time.time()
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
-            except Exception as e:
+            except Exception:
                 logger.exception(f"Failed to listen audio message, task_id: {task_id}")
                 break
         if tts_publisher:
@@ -640,6 +643,10 @@ class AdvancedChatAppGenerateTaskPipeline:
                     session.commit()
 
                 yield self._message_end_to_stream_response()
+            elif isinstance(event, QueueAgentLogEvent):
+                yield self._workflow_cycle_manager._handle_agent_log(
+                    task_id=self._application_generate_entity.task_id, event=event
+                )
             else:
                 continue
 
@@ -1,3 +1,4 @@
+import contextvars
 import logging
 import threading
 import uuid
@@ -29,17 +30,6 @@ logger = logging.getLogger(__name__)
 
 
 class AgentChatAppGenerator(MessageBasedAppGenerator):
-    @overload
-    def generate(
-        self,
-        *,
-        app_model: App,
-        user: Union[Account, EndUser],
-        args: Mapping[str, Any],
-        invoke_from: InvokeFrom,
-        streaming: Literal[True],
-    ) -> Generator[str, None, None]: ...
-
     @overload
     def generate(
         self,
@@ -51,6 +41,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         streaming: Literal[False],
     ) -> Mapping[str, Any]: ...
 
+    @overload
+    def generate(
+        self,
+        *,
+        app_model: App,
+        user: Union[Account, EndUser],
+        args: Mapping[str, Any],
+        invoke_from: InvokeFrom,
+        streaming: Literal[True],
+    ) -> Generator[Mapping | str, None, None]: ...
+
     @overload
     def generate(
         self,
@@ -60,7 +61,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool,
-    ) -> Mapping[str, Any] | Generator[str, None, None]: ...
+    ) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
 
     def generate(
         self,
@@ -70,7 +71,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool = True,
-    ):
+    ) -> Union[Mapping, Generator[Mapping | str, None, None]]:
         """
         Generate App response.
 
@@ -182,6 +183,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             target=self._generate_worker,
             kwargs={
                 "flask_app": current_app._get_current_object(),  # type: ignore
+                "context": contextvars.copy_context(),
                 "application_generate_entity": application_generate_entity,
                 "queue_manager": queue_manager,
                 "conversation_id": conversation.id,
@@ -206,6 +208,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
     def _generate_worker(
         self,
         flask_app: Flask,
+        context: contextvars.Context,
         application_generate_entity: AgentChatAppGenerateEntity,
         queue_manager: AppQueueManager,
         conversation_id: str,
@@ -220,6 +223,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
        :param message_id: message ID
        :return:
        """
+        for var, val in context.items():
+            var.set(val)
+
         with flask_app.app_context():
             try:
                 # get conversation and message
 
@@ -8,18 +8,16 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.base_app_runner import AppRunner
-from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
+from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationError
-from core.tools.entities.tool_entities import ToolRuntimeVariablePool
 from extensions.ext_database import db
-from models.model import App, Conversation, Message, MessageAgentThought
-from models.tools import ToolConversationVariables
+from models.model import App, Conversation, Message
 
 logger = logging.getLogger(__name__)
 
@@ -64,8 +62,8 @@ class AgentChatAppRunner(AppRunner):
             app_record=app_record,
             model_config=application_generate_entity.model_conf,
             prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
+            inputs=dict(inputs),
+            files=list(files),
             query=query,
         )
 
@@ -86,8 +84,8 @@ class AgentChatAppRunner(AppRunner):
             app_record=app_record,
             model_config=application_generate_entity.model_conf,
             prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
+            inputs=dict(inputs),
+            files=list(files),
             query=query,
             memory=memory,
         )
@@ -99,8 +97,8 @@ class AgentChatAppRunner(AppRunner):
                 app_id=app_record.id,
                 tenant_id=app_config.tenant_id,
                 app_generate_entity=application_generate_entity,
-                inputs=inputs,
-                query=query,
+                inputs=dict(inputs),
+                query=query or "",
                 message_id=message.id,
             )
         except ModerationError as e:
@@ -156,9 +154,9 @@ class AgentChatAppRunner(AppRunner):
                 app_record=app_record,
                 model_config=application_generate_entity.model_conf,
                 prompt_template_entity=app_config.prompt_template,
-                inputs=inputs,
-                files=files,
-                query=query,
+                inputs=dict(inputs),
+                files=list(files),
+                query=query or "",
                 memory=memory,
             )
 
@@ -173,16 +171,7 @@ class AgentChatAppRunner(AppRunner):
                 return
 
         agent_entity = app_config.agent
-        if not agent_entity:
-            raise ValueError("Agent entity not found")
-
-        # load tool variables
-        tool_conversation_variables = self._load_tool_variables(
-            conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
-        )
-
-        # convert db variables to tool variables
-        tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
+        assert agent_entity is not None
 
        # init model instance
        model_instance = ModelInstance(
@@ -193,16 +182,16 @@ class AgentChatAppRunner(AppRunner):
             app_record=app_record,
             model_config=application_generate_entity.model_conf,
             prompt_template_entity=app_config.prompt_template,
-            inputs=inputs,
-            files=files,
-            query=query,
+            inputs=dict(inputs),
+            files=list(files),
+            query=query or "",
             memory=memory,
         )
 
         # change function call strategy based on LLM model
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
         model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
-        if not model_schema or not model_schema.features:
+        if not model_schema:
             raise ValueError("Model schema not found")
 
         if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
@@ -243,8 +232,6 @@ class AgentChatAppRunner(AppRunner):
             user_id=application_generate_entity.user_id,
             memory=memory,
             prompt_messages=prompt_message,
-            variables_pool=tool_variables,
-            db_variables=tool_conversation_variables,
             model_instance=model_instance,
         )
 
@@ -261,73 +248,3 @@ class AgentChatAppRunner(AppRunner):
             stream=application_generate_entity.stream,
             agent=True,
         )
-
-    def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
-        """
-        load tool variables from database
-        """
-        tool_variables: ToolConversationVariables | None = (
-            db.session.query(ToolConversationVariables)
-            .filter(
-                ToolConversationVariables.conversation_id == conversation_id,
-                ToolConversationVariables.tenant_id == tenant_id,
-            )
-            .first()
-        )
-
-        if tool_variables:
-            # save tool variables to session, so that we can update it later
-            db.session.add(tool_variables)
-        else:
-            # create new tool variables
-            tool_variables = ToolConversationVariables(
-                conversation_id=conversation_id,
-                user_id=user_id,
-                tenant_id=tenant_id,
-                variables_str="[]",
-            )
-            db.session.add(tool_variables)
-            db.session.commit()
-
-        return tool_variables
-
-    def _convert_db_variables_to_tool_variables(
-        self, db_variables: ToolConversationVariables
-    ) -> ToolRuntimeVariablePool:
-        """
-        convert db variables to tool variables
-        """
-        return ToolRuntimeVariablePool(
-            **{
-                "conversation_id": db_variables.conversation_id,
-                "user_id": db_variables.user_id,
-                "tenant_id": db_variables.tenant_id,
-                "pool": db_variables.variables,
-            }
-        )
-
-    def _get_usage_of_all_agent_thoughts(
-        self, model_config: ModelConfigWithCredentialsEntity, message: Message
-    ) -> LLMUsage:
-        """
-        Get usage of all agent thoughts
-        :param model_config: model config
-        :param message: message
-        :return:
-        """
-        agent_thoughts = (
-            db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
-        )
-
-        all_message_tokens = 0
-        all_answer_tokens = 0
-        for agent_thought in agent_thoughts:
-            all_message_tokens += agent_thought.message_tokens
-            all_answer_tokens += agent_thought.answer_tokens
-
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        return model_type_instance._calc_response_usage(
-            model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
-        )
@@ -1,9 +1,9 @@
 import json
 from collections.abc import Generator
-from typing import cast
 
 from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
 from core.app.entities.task_entities import (
+    AppStreamResponse,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
     ErrorStreamResponse,
@@ -51,10 +51,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         return response
 
     @classmethod
-    def convert_stream_full_response(  # type: ignore[override]
-        cls,
-        stream_response: Generator[ChatbotAppStreamResponse, None, None],
-    ) -> Generator[str, None, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[dict | str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@ -80,13 +79,12 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield json.dumps(response_chunk)
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response( # type: ignore[override]
|
||||
cls,
|
||||
stream_response: Generator[ChatbotAppStreamResponse, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||

@@ -118,4 +116,4 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             else:
                 response_chunk.update(sub_stream_response.to_dict())

-            yield json.dumps(response_chunk)
+            yield response_chunk
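
The hunks above are the core of this refactor: the converters used to json.dumps every chunk themselves, and now they yield the raw chunk dicts (plus the "ping" keep-alive string) and leave serialization to the transport layer. A hedged sketch of consuming such a generator; the stream contents here are invented for illustration:

import json
from collections.abc import Generator

def fake_converted_stream() -> Generator[dict | str, None, None]:
    # Illustrative converter output: dict chunks plus the "ping" keep-alive.
    yield {"event": "message", "answer": "Hi"}
    yield "ping"

for chunk in fake_converted_stream():
    # Serialization now happens once, at the transport boundary.
    print(chunk if isinstance(chunk, str) else json.dumps(chunk))

Keeping chunks structured until the last moment lets blocking and streaming consumers share one code path.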

@@ -14,21 +14,15 @@ class AppGenerateResponseConverter(ABC):

     @classmethod
     def convert(
-        cls,
-        response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]],
-        invoke_from: InvokeFrom,
-    ) -> Mapping[str, Any] | Generator[str, None, None]:
+        cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
+    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
         if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)
             else:

-                def _generate_full_response() -> Generator[str, Any, None]:
-                    for chunk in cls.convert_stream_full_response(response):
-                        if chunk == "ping":
-                            yield f"event: {chunk}\n\n"
-                        else:
-                            yield f"data: {chunk}\n\n"
+                def _generate_full_response() -> Generator[dict | str, Any, None]:
+                    yield from cls.convert_stream_full_response(response)

                 return _generate_full_response()
         else:

@@ -36,12 +30,8 @@ class AppGenerateResponseConverter(ABC):
                 return cls.convert_blocking_simple_response(response)
             else:

-                def _generate_simple_response() -> Generator[str, Any, None]:
-                    for chunk in cls.convert_stream_simple_response(response):
-                        if chunk == "ping":
-                            yield f"event: {chunk}\n\n"
-                        else:
-                            yield f"data: {chunk}\n\n"
+                def _generate_simple_response() -> Generator[dict | str, Any, None]:
+                    yield from cls.convert_stream_simple_response(response)

                 return _generate_simple_response()

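
With the two hunks above, convert() no longer knows anything about SSE: the inline data:/event: framing is deleted and the inner generators become plain pass-throughs. Roughly, the old and new shapes compare like this (a simplified sketch, not the actual class methods):

from collections.abc import Generator, Iterable

def old_style(chunks: Iterable[str]) -> Generator[str, None, None]:
    # Before: convert() framed every chunk as an SSE line itself.
    for chunk in chunks:
        if chunk == "ping":
            yield f"event: {chunk}\n\n"
        else:
            yield f"data: {chunk}\n\n"

def new_style(chunks: Iterable[dict | str]) -> Generator[dict | str, None, None]:
    # After: chunks pass through untouched; SSE framing is applied later
    # by BaseAppGenerator.convert_to_event_stream (added a few hunks below).
    yield from chunks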

@@ -59,14 +49,14 @@ class AppGenerateResponseConverter(ABC):
     @abstractmethod
     def convert_stream_full_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, None, None]:
+    ) -> Generator[dict | str, None, None]:
         raise NotImplementedError

     @classmethod
     @abstractmethod
     def convert_stream_simple_response(
         cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, None, None]:
+    ) -> Generator[dict | str, None, None]:
         raise NotImplementedError

     @classmethod

@@ -1,5 +1,6 @@
-from collections.abc import Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional
+import json
+from collections.abc import Generator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Optional, Union

 from core.app.app_config.entities import VariableEntityType
 from core.file import File, FileUploadConfig

@@ -138,3 +139,21 @@ class BaseAppGenerator:
         if isinstance(value, str):
             return value.replace("\x00", "")
         return value
+
+    @classmethod
+    def convert_to_event_stream(cls, generator: Union[Mapping, Generator[Mapping | str, None, None]]):
+        """
+        Convert messages into event stream
+        """
+        if isinstance(generator, dict):
+            return generator
+        else:
+
+            def gen():
+                for message in generator:
+                    if isinstance(message, (Mapping, dict)):
+                        yield f"data: {json.dumps(message)}\n\n"
+                    else:
+                        yield f"event: {message}\n\n"
+
+            return gen()
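
convert_to_event_stream above is where the SSE framing removed from convert() ends up: mappings become data: frames, anything else (in practice the "ping" keep-alive) becomes an event: frame. A small standalone sketch of the same logic and the frames it produces:

import json
from collections.abc import Generator, Mapping

def event_stream_sketch(messages) -> Generator[str, None, None]:
    # Mirrors the gen() helper added above, outside its class for illustration.
    for message in messages:
        if isinstance(message, Mapping):
            yield f"data: {json.dumps(message)}\n\n"
        else:
            yield f"event: {message}\n\n"

frames = list(event_stream_sketch([{"answer": "Hi"}, "ping"]))
# frames[0] == 'data: {"answer": "Hi"}\n\n'
# frames[1] == 'event: ping\n\n'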

@@ -2,7 +2,7 @@ import queue
 import time
 from abc import abstractmethod
 from enum import Enum
-from typing import Any
+from typing import Any, Optional

 from sqlalchemy.orm import DeclarativeMeta


@@ -115,7 +115,7 @@ class AppQueueManager:
         Set task stop flag
         :return:
         """
-        result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
+        result: Optional[Any] = redis_client.get(cls._generate_task_belong_cache_key(task_id))
         if result is None:
             return

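
The only change in the hunk above is the Optional[Any] annotation: redis_client.get returns None for a missing key, so annotating the result makes the early `is None` return explicit to the type checker. The same pattern in isolation, with a plain dict standing in for Redis and an invented key format:

from typing import Any, Optional

def get_task_owner(cache: dict[str, bytes], task_id: str) -> Optional[str]:
    # Stand-in for redis_client.get(); returns None when the key is absent.
    result: Optional[Any] = cache.get(f"generate_task_belong:{task_id}")
    if result is None:
        return None
    return result.decode("utf-8")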

@@ -167,8 +167,7 @@ class AppQueueManager:
         else:
             if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
                 raise TypeError(
-                    "Critical Error: Passing SQLAlchemy Model instances "
-                    "that cause thread safety issues is not allowed."
+                    "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
                 )

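
Background for the message-string cleanup above: SQLAlchemy model instances are bound to the session (and thread) that loaded them, so letting one onto a cross-thread queue invites detached-instance bugs, which is what this guard rejects up front. A minimal sketch of the same defensive check; publish_safely is a hypothetical wrapper, not the real publish method:

import queue

def publish_safely(q: "queue.Queue[object]", data: object) -> None:
    # _sa_instance_state is present on every SQLAlchemy-mapped instance,
    # so this rejects ORM objects before they cross a thread boundary.
    if hasattr(data, "_sa_instance_state"):
        raise TypeError(
            "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
        )
    q.put(data)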

@@ -38,7 +38,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: Literal[True],
-    ) -> Generator[str, None, None]: ...
+    ) -> Generator[Mapping | str, None, None]: ...

     @overload
     def generate(

@@ -58,7 +58,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool,
-    ) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
+    ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...

     def generate(
         self,

@@ -67,7 +67,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
         args: Mapping[str, Any],
         invoke_from: InvokeFrom,
         streaming: bool = True,
-    ):
+    ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
         """
         Generate App response.

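
The three hunks above widen ChatAppGenerator.generate's streaming return type from Generator[str, ...] to also admit mappings, in the @overload stubs and the implementation alike. The pattern in miniature, with invented types rather than the real signature:

from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload

@overload
def generate(streaming: Literal[True] = ...) -> Generator[Mapping[str, Any] | str, None, None]: ...
@overload
def generate(streaming: Literal[False]) -> Mapping[str, Any]: ...
def generate(
    streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
    if streaming:
        def stream() -> Generator[Mapping[str, Any] | str, None, None]:
            yield {"event": "message"}  # dict chunks are now part of the contract
            yield "ping"                # alongside the raw keep-alive string
        return stream()
    return {"event": "message"}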

Some files were not shown because too many files have changed in this diff.