mirror of
https://github.com/langgenius/dify.git
synced 2026-04-22 03:37:44 +08:00
Compare commits
474 Commits
1.9.0
...
test/build
| Author | SHA1 | Date | |
|---|---|---|---|
| bc691464a2 | |||
| d444fa1c70 | |||
| b3a4721815 | |||
| 4637435e42 | |||
| 7a2e951474 | |||
| 1e127df4ab | |||
| ca7794305b | |||
| fd255e81e1 | |||
| 09d31d1263 | |||
| 47dc26f011 | |||
| 123bb3ec08 | |||
| 90f77282e3 | |||
| 5208867ccc | |||
| edc7ccc795 | |||
| c9798f6425 | |||
| 20ecf7f1d0 | |||
| 9dcb780fcb | |||
| 1cb7b09933 | |||
| 2c62a77cf4 | |||
| b9bc48d8dd | |||
| ed234e311b | |||
| 9843fec393 | |||
| aa4cabdeb5 | |||
| eea713b668 | |||
| fc62538a94 | |||
| 7994144df7 | |||
| e153c483b6 | |||
| 422bb4d4bb | |||
| 87a80d7613 | |||
| e91105ca87 | |||
| 37903722fe | |||
| f4c82d0010 | |||
| fe50093c18 | |||
| 4317af1e90 | |||
| 61a0fcc2ea | |||
| f627348b11 | |||
| 87fb9a6b69 | |||
| 97a2e2ec2e | |||
| 68d357d7f6 | |||
| a103ad3ee7 | |||
| f65d5a9761 | |||
| 6e0a5f5bbd | |||
| 22f858152f | |||
| 775d2e14fc | |||
| 744b287e67 | |||
| c0fc5d98f0 | |||
| 08ea79d730 | |||
| f31b821cc0 | |||
| 34be16874f | |||
| e9738b891f | |||
| 829796514a | |||
| ef1db35f80 | |||
| f9c67621ca | |||
| e29e8e3180 | |||
| 7a81e720d4 | |||
| 55600c0eb1 | |||
| 35e41d7d68 | |||
| b610cf9a11 | |||
| c8e9edc024 | |||
| 471cd760d7 | |||
| 7f48c57edf | |||
| 6569801162 | |||
| 9dd83f50a7 | |||
| 59c56b1b0d | |||
| 94cd2de940 | |||
| 3c23375607 | |||
| 56047f638f | |||
| 9c01d3e775 | |||
| c85c87f3da | |||
| eaa02e3d55 | |||
| 0219222a60 | |||
| dba659b220 | |||
| ee6458768e | |||
| ed3d02dc6d | |||
| 95471b1188 | |||
| 6190cfbfd8 | |||
| 11f2f95103 | |||
| 2abbc14703 | |||
| b2b2816ade | |||
| 4461df1bd9 | |||
| f7f6b4a8b0 | |||
| 41be581594 | |||
| 20ad5b7ac2 | |||
| a1c0bd7a1c | |||
| fd7c4e8a6d | |||
| 41e549af14 | |||
| b7360140ee | |||
| c71f7c7613 | |||
| c905c47775 | |||
| 4ca7ba000c | |||
| f260627660 | |||
| 1e9142c213 | |||
| 82890fe38e | |||
| 7dc7c8af98 | |||
| addebc465a | |||
| 5ab315aeaf | |||
| f092bc1912 | |||
| 23b49b8304 | |||
| 9e97248ede | |||
| d532b06310 | |||
| 07a2281730 | |||
| 42385f3ffa | |||
| c597234374 | |||
| 3de73f07c6 | |||
| 0caeaf6e5c | |||
| 3395297c3e | |||
| e60a7c7143 | |||
| 0e62a66cc2 | |||
| ff32dff163 | |||
| 543c5236e7 | |||
| 341b3ae7c9 | |||
| f01907aac2 | |||
| a7c855cab8 | |||
| 29afc0657d | |||
| d9860b8907 | |||
| dc1ae57dc6 | |||
| d6bd2a9bdb | |||
| c9eed67cf6 | |||
| 0ded6303c1 | |||
| b6e0abadab | |||
| 43bcf40f80 | |||
| f06025a342 | |||
| 24fb95b050 | |||
| 49fca63927 | |||
| ce5fe86430 | |||
| 666586b59c | |||
| 8a2851551a | |||
| a2fe4a28c3 | |||
| 417ebd160b | |||
| 82be305680 | |||
| 03002f4971 | |||
| 1e7e8a8988 | |||
| a715d5ac23 | |||
| 398c8117fe | |||
| f45c18ee35 | |||
| 15c1db42dd | |||
| a31c01f8d9 | |||
| 62753cdf13 | |||
| dc7ce125ad | |||
| eabdb09f8e | |||
| fa6d03c979 | |||
| 634fb192ef | |||
| a4b38e7521 | |||
| 8ff6de91b0 | |||
| 7fa0ad3161 | |||
| 53b21eea61 | |||
| 2f3a61b51b | |||
| 8bca7814f4 | |||
| 92c81b1833 | |||
| 44553d412c | |||
| 95ce224df0 | |||
| 8555635967 | |||
| e843fe8aa6 | |||
| b198c9474a | |||
| 4bb00b83d9 | |||
| c91cbf6b97 | |||
| f6ede6f1c1 | |||
| 65976b27fe | |||
| 2d73ee64a3 | |||
| c61c2b0abd | |||
| 40d3332690 | |||
| 8e45753c68 | |||
| 73e217ab0d | |||
| 26ff59172e | |||
| bebb4ffbaa | |||
| 523da66134 | |||
| e1ca7a9bdb | |||
| 9a8cf709ba | |||
| f909040567 | |||
| 845adb664a | |||
| 0c6cae2d59 | |||
| a893ee0ffc | |||
| 82b63cc6e2 | |||
| c327cfa86e | |||
| 82219c1162 | |||
| cfc3f1527a | |||
| caf1a5fbab | |||
| 4a6398fc1f | |||
| 2bcf96565a | |||
| 9a9d6a4a2b | |||
| 05f66fcf0d | |||
| ea8245a91b | |||
| 759a932bb7 | |||
| fb6f05c267 | |||
| ff9b74efeb | |||
| d6e7543ba6 | |||
| e45d5700ec | |||
| 4e6682bd85 | |||
| 32c715c4d0 | |||
| c11cdf7468 | |||
| 6217c96576 | |||
| 977690590e | |||
| fd845c8b6c | |||
| d7d9abb007 | |||
| 9f22b2726b | |||
| f28b519556 | |||
| 762cf91133 | |||
| 9dd3dcff2b | |||
| 34fbcc9457 | |||
| 9cc8ac981b | |||
| 1153dcef69 | |||
| f811471b18 | |||
| 2382229c7d | |||
| f0e739be43 | |||
| 4dccdf9478 | |||
| 4c37d650d3 | |||
| 1b334e6966 | |||
| d463bd6323 | |||
| 8c298b33cd | |||
| dc1a380888 | |||
| 7e9be4d3d9 | |||
| 5579521ffc | |||
| ab1059134d | |||
| fe2ac66a52 | |||
| f87db2652b | |||
| 3f9f02b9e7 | |||
| 578247ffbc | |||
| 9a5f214623 | |||
| 141ca8904a | |||
| 4488c090b2 | |||
| 59c1fde351 | |||
| cf7ff76165 | |||
| ac79691d69 | |||
| 1a37989769 | |||
| 830f891a74 | |||
| 5937a66e22 | |||
| 894e38f713 | |||
| e4b5b0e5fd | |||
| 598dd1f816 | |||
| 35e24d4d14 | |||
| fea2ffb3ba | |||
| 64f55d55a1 | |||
| bfda4ce7e6 | |||
| 4f7cb7cd2a | |||
| 6517323add | |||
| 531a0b755a | |||
| 91bb8ae4d2 | |||
| 8cafc20098 | |||
| 9d5300440c | |||
| 58524d6d2b | |||
| 19cc6ea993 | |||
| d7f0a31e24 | |||
| 312974aa20 | |||
| d19c100166 | |||
| a8ad80c405 | |||
| 650e38e17f | |||
| 24612adf2c | |||
| 06649f6c21 | |||
| 8b61f5e9c4 | |||
| 6432898e7a | |||
| cced33d068 | |||
| bd01af6415 | |||
| 35011b810d | |||
| f295c7532c | |||
| 7065b67d07 | |||
| c0b50ef61d | |||
| 1d8cca4fa2 | |||
| 3474c179e6 | |||
| 433dad7e1a | |||
| be7ee380bc | |||
| cff5de626b | |||
| 4d8b8f9210 | |||
| a16ef7e73c | |||
| c39dae06d4 | |||
| f906e70f6b | |||
| 5139119307 | |||
| 1b537f904a | |||
| 556b631c54 | |||
| 49df9ceaf3 | |||
| 92ec1ac27a | |||
| e74097afdf | |||
| 8ddc4f2292 | |||
| 7b51320346 | |||
| 9e39be0770 | |||
| 3e5e87930c | |||
| 15a5ba67f1 | |||
| 9e3b4dc90d | |||
| 48c42a9fba | |||
| 0b35bc1ede | |||
| 8e01bb40fe | |||
| 9d21772820 | |||
| b745839bdb | |||
| 59ad6e02ce | |||
| a3b33cbe28 | |||
| 7b8540281a | |||
| 0a6b78f883 | |||
| 56ee8f7d64 | |||
| 3cfcd32876 | |||
| 06dcb55a9d | |||
| ec6cafd7aa | |||
| 6e9858960d | |||
| 150a8276b9 | |||
| c6a90d4bb3 | |||
| c71fd7113c | |||
| 5fc104a992 | |||
| d1de3cfb94 | |||
| 44d36f2460 | |||
| 9088f151d9 | |||
| c692962650 | |||
| f0a60a9000 | |||
| 2f50f3fd4b | |||
| 24cd7bbc62 | |||
| d299e75e1b | |||
| f86b6658c9 | |||
| 0a56d65581 | |||
| dfc03bac9f | |||
| 81e1376e08 | |||
| f50c85d536 | |||
| 5830c69694 | |||
| 0173496a77 | |||
| 30c5b47699 | |||
| e3191d4e91 | |||
| a9b3539b90 | |||
| 5217017e69 | |||
| bd5df5cf1c | |||
| 456dbfe7d7 | |||
| 586f210d6e | |||
| 275a0f9ddd | |||
| cbf2ba6cec | |||
| 1bd621f819 | |||
| bb6a331490 | |||
| 3922ad876f | |||
| fdb53fdeb1 | |||
| 3fb5a7bff1 | |||
| 6157c67cfe | |||
| fbc745764a | |||
| 78f09801b5 | |||
| d0dd81cf84 | |||
| 65b832c46c | |||
| a90b60c36f | |||
| 94a07706ec | |||
| ab2eacb6c1 | |||
| aead192743 | |||
| c1e8584b97 | |||
| 8a2b208299 | |||
| 2b6882bd97 | |||
| aa51662d98 | |||
| 3068526797 | |||
| 298d8c2d88 | |||
| 294e01a8c1 | |||
| 3a5aa4587c | |||
| cf1778e696 | |||
| 54db4c176a | |||
| 5d3e8a31d0 | |||
| 885dff82e3 | |||
| 3c4aa24198 | |||
| 33b0814323 | |||
| 45ae511036 | |||
| 0fa063c640 | |||
| 40d35304ea | |||
| 89821d66bb | |||
| 09d84e900c | |||
| a8746bff30 | |||
| c4d8bf0ce9 | |||
| 9cca605bac | |||
| dbd23f91e5 | |||
| 9387cc088c | |||
| 11f7a89e25 | |||
| 654d522b31 | |||
| 31e6ef77a6 | |||
| e56c847210 | |||
| e00172199a | |||
| 04f47836d8 | |||
| faaca822e4 | |||
| dc0f053925 | |||
| 517726da3a | |||
| 1d6c03eddf | |||
| fdfccd1205 | |||
| b30e7ced0a | |||
| 11770439be | |||
| d89c5f7146 | |||
| 4a475bf1cd | |||
| 10be9cfbbf | |||
| c20e0ad90d | |||
| 22f64d60bb | |||
| 7b7d332239 | |||
| b1d189324a | |||
| 00fb468f2e | |||
| bbbb6e04cb | |||
| f5161d9add | |||
| 787251f00e | |||
| cfe21f0826 | |||
| 196f691865 | |||
| 7a5bb1cfac | |||
| b80d55b764 | |||
| dd71625f52 | |||
| 19936d23d1 | |||
| decf0f3da0 | |||
| 7242a67f84 | |||
| c4884eb669 | |||
| d49f3327e4 | |||
| 633e68a2f7 | |||
| 809f48f733 | |||
| 578b1b45ea | |||
| 86c3c58e64 | |||
| 8d803a26eb | |||
| aa3129c2a9 | |||
| 97c924fe29 | |||
| 591c463e4b | |||
| e1691fddaa | |||
| b4d4351203 | |||
| f7b1348623 | |||
| 2619c7553a | |||
| f79d8baf63 | |||
| bbdcbac544 | |||
| d552680e72 | |||
| df43c6ab8a | |||
| cd47a47c3b | |||
| e5d4235f1b | |||
| f60aa36fa0 | |||
| b2bcb6d21a | |||
| b6cea71023 | |||
| 6462328620 | |||
| fd86cadf67 | |||
| c43c72c1a3 | |||
| d77c2e4d17 | |||
| 1a7898dff1 | |||
| af662b100b | |||
| 595df172a8 | |||
| 70bc5ca7f4 | |||
| 30617feff8 | |||
| 756864c85b | |||
| c8c94ef870 | |||
| 10d51ada59 | |||
| 00f3a53f1c | |||
| d2f0551170 | |||
| cba2b9b2ad | |||
| 029d5d36ac | |||
| 8d897153a5 | |||
| 2e914808ea | |||
| d00a72a435 | |||
| 36580221aa | |||
| e686cc9eab | |||
| 66196459d5 | |||
| a5387b304e | |||
| beb1448441 | |||
| 272102c06d | |||
| 36406cd62f | |||
| 87c41c88a3 | |||
| 095c56a646 | |||
| 244c132656 | |||
| 043ec46c33 | |||
| 0e4f19eee0 | |||
| ff34969f21 | |||
| 9a7245e1df | |||
| 4906eeac18 | |||
| 4da93ba579 | |||
| 319ecdd312 | |||
| 0c1ec35244 | |||
| 46375aacdb | |||
| e6d4331994 | |||
| 2a0abc51b1 | |||
| 3bb67885ef | |||
| e682749d03 | |||
| 9b83b0aadd | |||
| 0cac330bc2 | |||
| fb8114792a | |||
| eab6f65409 | |||
| 915023b809 | |||
| f104839672 | |||
| 6841a09667 | |||
| e937c8c72e | |||
| 960bb8a9b4 | |||
| 9b36059292 | |||
| a4acc64afd | |||
| 25c69ac540 | |||
| 96a0b9991e | |||
| 2913d17fe2 | |||
| d9e45a1abe | |||
| 24b4289d6c | |||
| fb6ccccc3d | |||
| 8b74ae683a | |||
| dd08957381 | |||
| 407323f817 |
@ -1,4 +1,4 @@
|
|||||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
|
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
|
||||||
|
|
||||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||||
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
||||||
|
|||||||
@ -11,7 +11,7 @@
|
|||||||
"nodeGypDependencies": true,
|
"nodeGypDependencies": true,
|
||||||
"version": "lts"
|
"version": "lts"
|
||||||
},
|
},
|
||||||
"ghcr.io/devcontainers-contrib/features/npm-package:1": {
|
"ghcr.io/devcontainers-extra/features/npm-package:1": {
|
||||||
"package": "typescript",
|
"package": "typescript",
|
||||||
"version": "latest"
|
"version": "latest"
|
||||||
},
|
},
|
||||||
|
|||||||
@ -6,7 +6,7 @@ cd web && pnpm install
|
|||||||
pipx install uv
|
pipx install uv
|
||||||
|
|
||||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
||||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||||
|
|||||||
3
.github/ISSUE_TEMPLATE/config.yml
vendored
3
.github/ISSUE_TEMPLATE/config.yml
vendored
@ -1,5 +1,8 @@
|
|||||||
blank_issues_enabled: false
|
blank_issues_enabled: false
|
||||||
contact_links:
|
contact_links:
|
||||||
|
- name: "\U0001F510 Security Vulnerabilities"
|
||||||
|
url: "https://github.com/langgenius/dify/security/advisories/new"
|
||||||
|
about: Report security vulnerabilities through GitHub Security Advisories to ensure responsible disclosure. 💡 Please do not report security vulnerabilities in public issues.
|
||||||
- name: "\U0001F4A1 Model Providers & Plugins"
|
- name: "\U0001F4A1 Model Providers & Plugins"
|
||||||
url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
|
url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
|
||||||
about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details.
|
about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details.
|
||||||
|
|||||||
30
.github/workflows/api-tests.yml
vendored
30
.github/workflows/api-tests.yml
vendored
@ -39,25 +39,11 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: uv sync --project api --dev
|
run: uv sync --project api --dev
|
||||||
|
|
||||||
- name: Run Unit tests
|
|
||||||
run: |
|
|
||||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
|
||||||
|
|
||||||
- name: Run pyrefly check
|
- name: Run pyrefly check
|
||||||
run: |
|
run: |
|
||||||
cd api
|
cd api
|
||||||
uv add --dev pyrefly
|
uv add --dev pyrefly
|
||||||
uv run pyrefly check || true
|
uv run pyrefly check || true
|
||||||
- name: Coverage Summary
|
|
||||||
run: |
|
|
||||||
set -x
|
|
||||||
# Extract coverage percentage and create a summary
|
|
||||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
|
||||||
|
|
||||||
# Create a detailed coverage summary
|
|
||||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
|
||||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
|
||||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
|
||||||
|
|
||||||
- name: Run dify config tests
|
- name: Run dify config tests
|
||||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||||
@ -93,3 +79,19 @@ jobs:
|
|||||||
|
|
||||||
- name: Run TestContainers
|
- name: Run TestContainers
|
||||||
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
||||||
|
|
||||||
|
- name: Run Unit tests
|
||||||
|
run: |
|
||||||
|
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||||
|
|
||||||
|
- name: Coverage Summary
|
||||||
|
run: |
|
||||||
|
set -x
|
||||||
|
# Extract coverage percentage and create a summary
|
||||||
|
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||||
|
|
||||||
|
# Create a detailed coverage summary
|
||||||
|
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||||
|
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||||
|
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||||
|
|
||||||
|
|||||||
6
.github/workflows/autofix.yml
vendored
6
.github/workflows/autofix.yml
vendored
@ -15,10 +15,12 @@ jobs:
|
|||||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||||
- uses: astral-sh/setup-uv@v6
|
- uses: astral-sh/setup-uv@v6
|
||||||
with:
|
with:
|
||||||
python-version: "3.12"
|
python-version: "3.11"
|
||||||
- run: |
|
- run: |
|
||||||
cd api
|
cd api
|
||||||
uv sync --dev
|
uv sync --dev
|
||||||
|
# fmt first to avoid line too long
|
||||||
|
uv run ruff format ..
|
||||||
# Fix lint errors
|
# Fix lint errors
|
||||||
uv run ruff check --fix .
|
uv run ruff check --fix .
|
||||||
# Format code
|
# Format code
|
||||||
@ -28,6 +30,8 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||||
|
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
|
||||||
|
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
|
||||||
# Convert Optional[T] to T | None (ignoring quoted types)
|
# Convert Optional[T] to T | None (ignoring quoted types)
|
||||||
cat > /tmp/optional-rule.yml << 'EOF'
|
cat > /tmp/optional-rule.yml << 'EOF'
|
||||||
id: convert-optional-to-union
|
id: convert-optional-to-union
|
||||||
|
|||||||
4
.github/workflows/build-push.yml
vendored
4
.github/workflows/build-push.yml
vendored
@ -4,10 +4,10 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- "main"
|
- "main"
|
||||||
- "deploy/dev"
|
- "deploy/**"
|
||||||
- "deploy/enterprise"
|
|
||||||
- "build/**"
|
- "build/**"
|
||||||
- "release/e-*"
|
- "release/e-*"
|
||||||
|
- "hotfix/**"
|
||||||
tags:
|
tags:
|
||||||
- "*"
|
- "*"
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/deploy-dev.yml
vendored
2
.github/workflows/deploy-dev.yml
vendored
@ -18,7 +18,7 @@ jobs:
|
|||||||
- name: Deploy to server
|
- name: Deploy to server
|
||||||
uses: appleboy/ssh-action@v0.1.8
|
uses: appleboy/ssh-action@v0.1.8
|
||||||
with:
|
with:
|
||||||
host: ${{ secrets.RAG_SSH_HOST }}
|
host: ${{ secrets.SSH_HOST }}
|
||||||
username: ${{ secrets.SSH_USER }}
|
username: ${{ secrets.SSH_USER }}
|
||||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
script: |
|
script: |
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
name: Deploy RAG Dev
|
name: Deploy Trigger Dev
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
@ -7,7 +7,7 @@ on:
|
|||||||
workflow_run:
|
workflow_run:
|
||||||
workflows: ["Build and Push API & Web"]
|
workflows: ["Build and Push API & Web"]
|
||||||
branches:
|
branches:
|
||||||
- "deploy/rag-dev"
|
- "deploy/trigger-dev"
|
||||||
types:
|
types:
|
||||||
- completed
|
- completed
|
||||||
|
|
||||||
@ -16,12 +16,12 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: |
|
if: |
|
||||||
github.event.workflow_run.conclusion == 'success' &&
|
github.event.workflow_run.conclusion == 'success' &&
|
||||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
|
||||||
steps:
|
steps:
|
||||||
- name: Deploy to server
|
- name: Deploy to server
|
||||||
uses: appleboy/ssh-action@v0.1.8
|
uses: appleboy/ssh-action@v0.1.8
|
||||||
with:
|
with:
|
||||||
host: ${{ secrets.RAG_SSH_HOST }}
|
host: ${{ secrets.TRIGGER_SSH_HOST }}
|
||||||
username: ${{ secrets.SSH_USER }}
|
username: ${{ secrets.SSH_USER }}
|
||||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||||
script: |
|
script: |
|
||||||
3
.github/workflows/expose_service_ports.sh
vendored
3
.github/workflows/expose_service_ports.sh
vendored
@ -1,6 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
||||||
|
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
||||||
@ -13,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
|
|||||||
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
|
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
|
||||||
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
|
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
|
||||||
|
|
||||||
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||||
|
|||||||
5
.github/workflows/style.yml
vendored
5
.github/workflows/style.yml
vendored
@ -103,6 +103,11 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pnpm run lint
|
pnpm run lint
|
||||||
|
|
||||||
|
- name: Web type check
|
||||||
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
|
working-directory: ./web
|
||||||
|
run: pnpm run type-check
|
||||||
|
|
||||||
docker-compose-template:
|
docker-compose-template:
|
||||||
name: Docker Compose Template
|
name: Docker Compose Template
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -97,6 +97,7 @@ __pypackages__/
|
|||||||
|
|
||||||
# Celery stuff
|
# Celery stuff
|
||||||
celerybeat-schedule
|
celerybeat-schedule
|
||||||
|
celerybeat-schedule.db
|
||||||
celerybeat.pid
|
celerybeat.pid
|
||||||
|
|
||||||
# SageMath parsed files
|
# SageMath parsed files
|
||||||
|
|||||||
9
.vscode/launch.json.template
vendored
9
.vscode/launch.json.template
vendored
@ -8,8 +8,7 @@
|
|||||||
"module": "flask",
|
"module": "flask",
|
||||||
"env": {
|
"env": {
|
||||||
"FLASK_APP": "app.py",
|
"FLASK_APP": "app.py",
|
||||||
"FLASK_ENV": "development",
|
"FLASK_ENV": "development"
|
||||||
"GEVENT_SUPPORT": "True"
|
|
||||||
},
|
},
|
||||||
"args": [
|
"args": [
|
||||||
"run",
|
"run",
|
||||||
@ -28,9 +27,7 @@
|
|||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"module": "celery",
|
"module": "celery",
|
||||||
"env": {
|
"env": {},
|
||||||
"GEVENT_SUPPORT": "True"
|
|
||||||
},
|
|
||||||
"args": [
|
"args": [
|
||||||
"-A",
|
"-A",
|
||||||
"app.celery",
|
"app.celery",
|
||||||
@ -40,7 +37,7 @@
|
|||||||
"-c",
|
"-c",
|
||||||
"1",
|
"1",
|
||||||
"-Q",
|
"-Q",
|
||||||
"dataset,generation,mail,ops_trace",
|
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
|
||||||
"--loglevel",
|
"--loglevel",
|
||||||
"INFO"
|
"INFO"
|
||||||
],
|
],
|
||||||
|
|||||||
89
AGENTS.md
89
AGENTS.md
@ -4,84 +4,51 @@
|
|||||||
|
|
||||||
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
|
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
|
||||||
|
|
||||||
The codebase consists of:
|
The codebase is split into:
|
||||||
|
|
||||||
- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
|
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
|
||||||
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
|
- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
|
||||||
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
||||||
|
|
||||||
## Development Commands
|
## Backend Workflow
|
||||||
|
|
||||||
### Backend (API)
|
- Run backend CLI commands through `uv run --project api <command>`.
|
||||||
|
|
||||||
All Python commands must be prefixed with `uv run --project api`:
|
- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
|
||||||
|
|
||||||
```bash
|
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
|
||||||
# Start development servers
|
|
||||||
./dev/start-api # Start API server
|
|
||||||
./dev/start-worker # Start Celery worker
|
|
||||||
|
|
||||||
# Run tests
|
- Integration tests are CI-only and are not expected to run in the local environment.
|
||||||
uv run --project api pytest # Run all tests
|
|
||||||
uv run --project api pytest tests/unit_tests/ # Unit tests only
|
|
||||||
uv run --project api pytest tests/integration_tests/ # Integration tests
|
|
||||||
|
|
||||||
# Code quality
|
## Frontend Workflow
|
||||||
./dev/reformat # Run all formatters and linters
|
|
||||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
|
||||||
uv run --project api ruff format ./ # Format code
|
|
||||||
uv run --directory api basedpyright # Type checking
|
|
||||||
```
|
|
||||||
|
|
||||||
### Frontend (Web)
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd web
|
cd web
|
||||||
pnpm lint # Run ESLint
|
pnpm lint
|
||||||
pnpm eslint-fix # Fix ESLint issues
|
pnpm lint:fix
|
||||||
pnpm test # Run Jest tests
|
pnpm test
|
||||||
```
|
```
|
||||||
|
|
||||||
## Testing Guidelines
|
## Testing & Quality Practices
|
||||||
|
|
||||||
### Backend Testing
|
- Follow TDD: red → green → refactor.
|
||||||
|
- Use `pytest` for backend tests with Arrange-Act-Assert structure.
|
||||||
|
- Enforce strong typing; avoid `Any` and prefer explicit type annotations.
|
||||||
|
- Write self-documenting code; only add comments that explain intent.
|
||||||
|
|
||||||
- Use `pytest` for all backend tests
|
## Language Style
|
||||||
- Write tests first (TDD approach)
|
|
||||||
- Test structure: Arrange-Act-Assert
|
|
||||||
|
|
||||||
## Code Style Requirements
|
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
|
||||||
|
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types.
|
||||||
|
|
||||||
### Python
|
## General Practices
|
||||||
|
|
||||||
- Use type hints for all functions and class attributes
|
- Prefer editing existing files; add new documentation only when requested.
|
||||||
- No `Any` types unless absolutely necessary
|
- Inject dependencies through constructors and preserve clean architecture boundaries.
|
||||||
- Implement special methods (`__repr__`, `__str__`) appropriately
|
- Handle errors with domain-specific exceptions at the correct layer.
|
||||||
|
|
||||||
### TypeScript/JavaScript
|
## Project Conventions
|
||||||
|
|
||||||
- Strict TypeScript configuration
|
- Backend architecture adheres to DDD and Clean Architecture principles.
|
||||||
- ESLint with Prettier integration
|
- Async work runs through Celery with Redis as the broker.
|
||||||
- Avoid `any` type
|
- Frontend user-facing strings must use `web/i18n/en-US/`; avoid hardcoded text.
|
||||||
|
|
||||||
## Important Notes
|
|
||||||
|
|
||||||
- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
|
|
||||||
- **Comments**: Only write meaningful comments that explain "why", not "what"
|
|
||||||
- **File Creation**: Always prefer editing existing files over creating new ones
|
|
||||||
- **Documentation**: Don't create documentation files unless explicitly requested
|
|
||||||
- **Code Quality**: Always run `./dev/reformat` before committing backend changes
|
|
||||||
|
|
||||||
## Common Development Tasks
|
|
||||||
|
|
||||||
### Adding a New API Endpoint
|
|
||||||
|
|
||||||
1. Create controller in `/api/controllers/`
|
|
||||||
1. Add service logic in `/api/services/`
|
|
||||||
1. Update routes in controller's `__init__.py`
|
|
||||||
1. Write tests in `/api/tests/`
|
|
||||||
|
|
||||||
## Project-Specific Conventions
|
|
||||||
|
|
||||||
- All async tasks use Celery with Redis as broker
|
|
||||||
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
|
|
||||||
|
|||||||
1
Makefile
1
Makefile
@ -26,7 +26,6 @@ prepare-web:
|
|||||||
@echo "🌐 Setting up web environment..."
|
@echo "🌐 Setting up web environment..."
|
||||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||||
@cd web && pnpm install
|
@cd web && pnpm install
|
||||||
@cd web && pnpm build
|
|
||||||
@echo "✅ Web environment prepared (not started)"
|
@echo "✅ Web environment prepared (not started)"
|
||||||
|
|
||||||
# Step 3: Prepare API environment
|
# Step 3: Prepare API environment
|
||||||
|
|||||||
44
README.md
44
README.md
@ -40,18 +40,18 @@
|
|||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||||
<a href="./README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
<a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||||
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
<a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||||
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
<a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||||
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
<a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||||
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
<a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||||
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
<a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||||
<a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
<a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||||
<a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
<a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||||
<a href="./README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
<a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||||
<a href="./README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
<a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||||
<a href="./README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
<a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||||
<a href="./README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
|
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
|
||||||
@ -63,7 +63,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
|
|||||||
> - CPU >= 2 Core
|
> - CPU >= 2 Core
|
||||||
> - RAM >= 4 GiB
|
> - RAM >= 4 GiB
|
||||||
|
|
||||||
</br>
|
<br/>
|
||||||
|
|
||||||
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
||||||
|
|
||||||
@ -109,15 +109,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
|
|||||||
|
|
||||||
## Using Dify
|
## Using Dify
|
||||||
|
|
||||||
- **Cloud </br>**
|
- **Cloud <br/>**
|
||||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||||
|
|
||||||
- **Self-hosting Dify Community Edition</br>**
|
- **Self-hosting Dify Community Edition<br/>**
|
||||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||||
|
|
||||||
- **Dify for enterprise / organizations</br>**
|
- **Dify for enterprise / organizations<br/>**
|
||||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br>
|
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
|
||||||
|
|
||||||
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
||||||
|
|
||||||
@ -129,8 +129,18 @@ Star Dify on GitHub and be instantly notified of new releases.
|
|||||||
|
|
||||||
## Advanced Setup
|
## Advanced Setup
|
||||||
|
|
||||||
|
### Custom configurations
|
||||||
|
|
||||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||||
|
|
||||||
|
### Metrics Monitoring with Grafana
|
||||||
|
|
||||||
|
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
|
||||||
|
|
||||||
|
- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
|
||||||
|
|
||||||
|
### Deployment with Kubernetes
|
||||||
|
|
||||||
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||||
|
|
||||||
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
||||||
|
|||||||
@ -156,6 +156,9 @@ SUPABASE_URL=your-server-url
|
|||||||
# CORS configuration
|
# CORS configuration
|
||||||
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||||
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||||
|
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
|
||||||
|
# Provide the registrable domain (e.g. example.com); leading dots are optional.
|
||||||
|
COOKIE_DOMAIN=
|
||||||
|
|
||||||
# Vector database configuration
|
# Vector database configuration
|
||||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||||
@ -343,6 +346,15 @@ OCEANBASE_VECTOR_DATABASE=test
|
|||||||
OCEANBASE_MEMORY_LIMIT=6G
|
OCEANBASE_MEMORY_LIMIT=6G
|
||||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||||
|
|
||||||
|
# AlibabaCloud MySQL Vector configuration
|
||||||
|
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
|
||||||
|
ALIBABACLOUD_MYSQL_PORT=3306
|
||||||
|
ALIBABACLOUD_MYSQL_USER=root
|
||||||
|
ALIBABACLOUD_MYSQL_PASSWORD=root
|
||||||
|
ALIBABACLOUD_MYSQL_DATABASE=dify
|
||||||
|
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
|
||||||
|
ALIBABACLOUD_MYSQL_HNSW_M=6
|
||||||
|
|
||||||
# openGauss configuration
|
# openGauss configuration
|
||||||
OPENGAUSS_HOST=127.0.0.1
|
OPENGAUSS_HOST=127.0.0.1
|
||||||
OPENGAUSS_PORT=6600
|
OPENGAUSS_PORT=6600
|
||||||
@ -359,6 +371,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
|||||||
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||||
|
|
||||||
|
# Comma-separated list of file extensions blocked from upload for security reasons.
|
||||||
|
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
|
||||||
|
# Empty by default to allow all file types.
|
||||||
|
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
|
||||||
|
UPLOAD_FILE_EXTENSION_BLACKLIST=
|
||||||
|
|
||||||
# Model configuration
|
# Model configuration
|
||||||
MULTIMODAL_SEND_FORMAT=base64
|
MULTIMODAL_SEND_FORMAT=base64
|
||||||
PROMPT_GENERATION_MAX_TOKENS=512
|
PROMPT_GENERATION_MAX_TOKENS=512
|
||||||
@ -408,6 +426,9 @@ SSRF_DEFAULT_TIME_OUT=5
|
|||||||
SSRF_DEFAULT_CONNECT_TIME_OUT=5
|
SSRF_DEFAULT_CONNECT_TIME_OUT=5
|
||||||
SSRF_DEFAULT_READ_TIME_OUT=5
|
SSRF_DEFAULT_READ_TIME_OUT=5
|
||||||
SSRF_DEFAULT_WRITE_TIME_OUT=5
|
SSRF_DEFAULT_WRITE_TIME_OUT=5
|
||||||
|
SSRF_POOL_MAX_CONNECTIONS=100
|
||||||
|
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
||||||
|
SSRF_POOL_KEEPALIVE_EXPIRY=5.0
|
||||||
|
|
||||||
BATCH_UPLOAD_LIMIT=10
|
BATCH_UPLOAD_LIMIT=10
|
||||||
KEYWORD_DATA_SOURCE_TYPE=database
|
KEYWORD_DATA_SOURCE_TYPE=database
|
||||||
@ -418,10 +439,17 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
|
|||||||
# CODE EXECUTION CONFIGURATION
|
# CODE EXECUTION CONFIGURATION
|
||||||
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
|
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
|
||||||
CODE_EXECUTION_API_KEY=dify-sandbox
|
CODE_EXECUTION_API_KEY=dify-sandbox
|
||||||
|
CODE_EXECUTION_SSL_VERIFY=True
|
||||||
|
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
|
||||||
|
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
||||||
|
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
|
||||||
|
CODE_EXECUTION_CONNECT_TIMEOUT=10
|
||||||
|
CODE_EXECUTION_READ_TIMEOUT=60
|
||||||
|
CODE_EXECUTION_WRITE_TIMEOUT=10
|
||||||
CODE_MAX_NUMBER=9223372036854775807
|
CODE_MAX_NUMBER=9223372036854775807
|
||||||
CODE_MIN_NUMBER=-9223372036854775808
|
CODE_MIN_NUMBER=-9223372036854775808
|
||||||
CODE_MAX_STRING_LENGTH=80000
|
CODE_MAX_STRING_LENGTH=400000
|
||||||
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
|
TEMPLATE_TRANSFORM_MAX_LENGTH=400000
|
||||||
CODE_MAX_STRING_ARRAY_LENGTH=30
|
CODE_MAX_STRING_ARRAY_LENGTH=30
|
||||||
CODE_MAX_OBJECT_ARRAY_LENGTH=30
|
CODE_MAX_OBJECT_ARRAY_LENGTH=30
|
||||||
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
||||||
@ -461,7 +489,6 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
|
|||||||
WORKFLOW_MAX_EXECUTION_STEPS=500
|
WORKFLOW_MAX_EXECUTION_STEPS=500
|
||||||
WORKFLOW_MAX_EXECUTION_TIME=1200
|
WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||||
WORKFLOW_CALL_MAX_DEPTH=5
|
WORKFLOW_CALL_MAX_DEPTH=5
|
||||||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
|
||||||
MAX_VARIABLE_SIZE=204800
|
MAX_VARIABLE_SIZE=204800
|
||||||
|
|
||||||
# GraphEngine Worker Pool Configuration
|
# GraphEngine Worker Pool Configuration
|
||||||
@ -587,3 +614,9 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
|||||||
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
|
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
|
||||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||||
|
|
||||||
|
# Tenant isolated task queue configuration
|
||||||
|
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||||
|
|
||||||
|
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||||
|
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||||
|
|||||||
@ -81,7 +81,6 @@ ignore = [
|
|||||||
"SIM113", # enumerate-for-loop
|
"SIM113", # enumerate-for-loop
|
||||||
"SIM117", # multiple-with-statements
|
"SIM117", # multiple-with-statements
|
||||||
"SIM210", # if-expr-with-true-false
|
"SIM210", # if-expr-with-true-false
|
||||||
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[lint.per-file-ignores]
|
[lint.per-file-ignores]
|
||||||
|
|||||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@ -54,7 +54,7 @@
|
|||||||
"--loglevel",
|
"--loglevel",
|
||||||
"DEBUG",
|
"DEBUG",
|
||||||
"-Q",
|
"-Q",
|
||||||
"dataset,generation,mail,ops_trace,app_deletion"
|
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@ -15,7 +15,11 @@ FROM base AS packages
|
|||||||
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
|
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
|
||||||
|
|
||||||
RUN apt-get update \
|
RUN apt-get update \
|
||||||
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
|
&& apt-get install -y --no-install-recommends \
|
||||||
|
# basic environment
|
||||||
|
g++ \
|
||||||
|
# for building gmpy2
|
||||||
|
libmpfr-dev libmpc-dev
|
||||||
|
|
||||||
# Install Python dependencies
|
# Install Python dependencies
|
||||||
COPY pyproject.toml uv.lock ./
|
COPY pyproject.toml uv.lock ./
|
||||||
@ -49,7 +53,9 @@ RUN \
|
|||||||
# Install dependencies
|
# Install dependencies
|
||||||
&& apt-get install -y --no-install-recommends \
|
&& apt-get install -y --no-install-recommends \
|
||||||
# basic environment
|
# basic environment
|
||||||
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
|
curl nodejs \
|
||||||
|
# for gmpy2 \
|
||||||
|
libgmp-dev libmpfr-dev libmpc-dev \
|
||||||
# For Security
|
# For Security
|
||||||
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||||
# install fonts to support the use of tools like pypdfium2
|
# install fonts to support the use of tools like pypdfium2
|
||||||
|
|||||||
@ -80,10 +80,10 @@
|
|||||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
|
||||||
```
|
```
|
||||||
|
|
||||||
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
|
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run celery -A app.celery beat
|
uv run celery -A app.celery beat
|
||||||
|
|||||||
19
api/app.py
19
api/app.py
@ -13,23 +13,12 @@ if is_db_command():
|
|||||||
|
|
||||||
app = create_migrations_app()
|
app = create_migrations_app()
|
||||||
else:
|
else:
|
||||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||||
# so we need to disable gevent in debug mode.
|
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
|
||||||
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
|
||||||
# from gevent import monkey
|
|
||||||
#
|
#
|
||||||
# # gevent
|
# See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
|
||||||
# monkey.patch_all()
|
|
||||||
#
|
#
|
||||||
# from grpc.experimental import gevent as grpc_gevent # type: ignore
|
# For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
|
||||||
#
|
|
||||||
# # grpc gevent
|
|
||||||
# grpc_gevent.init_gevent()
|
|
||||||
|
|
||||||
# import psycogreen.gevent # type: ignore
|
|
||||||
#
|
|
||||||
# psycogreen.gevent.patch_psycopg()
|
|
||||||
|
|
||||||
from app_factory import create_app
|
from app_factory import create_app
|
||||||
|
|
||||||
|
|||||||
188
api/commands.py
188
api/commands.py
@ -10,6 +10,7 @@ from flask import current_app
|
|||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
@ -61,31 +62,30 @@ def reset_password(email, new_password, password_confirm):
|
|||||||
if str(new_password).strip() != str(password_confirm).strip():
|
if str(new_password).strip() != str(password_confirm).strip():
|
||||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||||
return
|
return
|
||||||
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
|
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||||
|
|
||||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
if not account:
|
||||||
|
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||||
|
return
|
||||||
|
|
||||||
if not account:
|
try:
|
||||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
valid_password(new_password)
|
||||||
return
|
except:
|
||||||
|
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
# generate password salt
|
||||||
valid_password(new_password)
|
salt = secrets.token_bytes(16)
|
||||||
except:
|
base64_salt = base64.b64encode(salt).decode()
|
||||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
|
||||||
return
|
|
||||||
|
|
||||||
# generate password salt
|
# encrypt password with salt
|
||||||
salt = secrets.token_bytes(16)
|
password_hashed = hash_password(new_password, salt)
|
||||||
base64_salt = base64.b64encode(salt).decode()
|
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||||
|
account.password = base64_password_hashed
|
||||||
# encrypt password with salt
|
account.password_salt = base64_salt
|
||||||
password_hashed = hash_password(new_password, salt)
|
AccountService.reset_login_error_rate_limit(email)
|
||||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||||
account.password = base64_password_hashed
|
|
||||||
account.password_salt = base64_salt
|
|
||||||
db.session.commit()
|
|
||||||
AccountService.reset_login_error_rate_limit(email)
|
|
||||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
|
||||||
|
|
||||||
|
|
||||||
@click.command("reset-email", help="Reset the account email.")
|
@click.command("reset-email", help="Reset the account email.")
|
||||||
@ -100,22 +100,21 @@ def reset_email(email, new_email, email_confirm):
|
|||||||
if str(new_email).strip() != str(email_confirm).strip():
|
if str(new_email).strip() != str(email_confirm).strip():
|
||||||
click.echo(click.style("New emails do not match.", fg="red"))
|
click.echo(click.style("New emails do not match.", fg="red"))
|
||||||
return
|
return
|
||||||
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
|
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||||
|
|
||||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
if not account:
|
||||||
|
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||||
|
return
|
||||||
|
|
||||||
if not account:
|
try:
|
||||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
email_validate(new_email)
|
||||||
return
|
except:
|
||||||
|
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
account.email = new_email
|
||||||
email_validate(new_email)
|
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||||
except:
|
|
||||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
|
||||||
return
|
|
||||||
|
|
||||||
account.email = new_email
|
|
||||||
db.session.commit()
|
|
||||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
|
||||||
|
|
||||||
|
|
||||||
@click.command(
|
@click.command(
|
||||||
@ -139,25 +138,24 @@ def reset_encrypt_key_pair():
|
|||||||
if dify_config.EDITION != "SELF_HOSTED":
|
if dify_config.EDITION != "SELF_HOSTED":
|
||||||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||||
return
|
return
|
||||||
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
|
tenants = session.query(Tenant).all()
|
||||||
|
for tenant in tenants:
|
||||||
|
if not tenant:
|
||||||
|
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||||
|
return
|
||||||
|
|
||||||
tenants = db.session.query(Tenant).all()
|
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||||
for tenant in tenants:
|
|
||||||
if not tenant:
|
|
||||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
|
||||||
return
|
|
||||||
|
|
||||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||||
|
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||||
|
|
||||||
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
click.echo(
|
||||||
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
click.style(
|
||||||
db.session.commit()
|
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||||
|
fg="green",
|
||||||
click.echo(
|
)
|
||||||
click.style(
|
|
||||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
|
||||||
fg="green",
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@click.command("vdb-migrate", help="Migrate vector db.")
|
@click.command("vdb-migrate", help="Migrate vector db.")
|
||||||
@ -182,14 +180,15 @@ def migrate_annotation_vector_database():
|
|||||||
try:
|
try:
|
||||||
# get apps info
|
# get apps info
|
||||||
per_page = 50
|
per_page = 50
|
||||||
apps = (
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
db.session.query(App)
|
apps = (
|
||||||
.where(App.status == "normal")
|
session.query(App)
|
||||||
.order_by(App.created_at.desc())
|
.where(App.status == "normal")
|
||||||
.limit(per_page)
|
.order_by(App.created_at.desc())
|
||||||
.offset((page - 1) * per_page)
|
.limit(per_page)
|
||||||
.all()
|
.offset((page - 1) * per_page)
|
||||||
)
|
.all()
|
||||||
|
)
|
||||||
if not apps:
|
if not apps:
|
||||||
break
|
break
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
@ -203,26 +202,27 @@ def migrate_annotation_vector_database():
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
click.echo(f"Creating app annotation index: {app.id}")
|
click.echo(f"Creating app annotation index: {app.id}")
|
||||||
app_annotation_setting = (
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
app_annotation_setting = (
|
||||||
)
|
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||||
|
)
|
||||||
|
|
||||||
if not app_annotation_setting:
|
if not app_annotation_setting:
|
||||||
skipped_count = skipped_count + 1
|
skipped_count = skipped_count + 1
|
||||||
click.echo(f"App annotation setting disabled: {app.id}")
|
click.echo(f"App annotation setting disabled: {app.id}")
|
||||||
continue
|
continue
|
||||||
# get dataset_collection_binding info
|
# get dataset_collection_binding info
|
||||||
dataset_collection_binding = (
|
dataset_collection_binding = (
|
||||||
db.session.query(DatasetCollectionBinding)
|
session.query(DatasetCollectionBinding)
|
||||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not dataset_collection_binding:
|
if not dataset_collection_binding:
|
||||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||||
continue
|
continue
|
||||||
annotations = db.session.scalars(
|
annotations = session.scalars(
|
||||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||||
).all()
|
).all()
|
||||||
dataset = Dataset(
|
dataset = Dataset(
|
||||||
id=app.id,
|
id=app.id,
|
||||||
tenant_id=app.tenant_id,
|
tenant_id=app.tenant_id,
|
||||||
@ -321,6 +321,8 @@ def migrate_knowledge_vector_database():
|
|||||||
)
|
)
|
||||||
|
|
||||||
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
||||||
|
if not datasets.items:
|
||||||
|
break
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -1420,7 +1422,10 @@ def setup_datasource_oauth_client(provider, client_params):
|
|||||||
|
|
||||||
|
|
||||||
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
||||||
def transform_datasource_credentials():
|
@click.option(
|
||||||
|
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
|
||||||
|
)
|
||||||
|
def transform_datasource_credentials(environment: str):
|
||||||
"""
|
"""
|
||||||
Transform datasource credentials
|
Transform datasource credentials
|
||||||
"""
|
"""
|
||||||
@ -1431,9 +1436,14 @@ def transform_datasource_credentials():
|
|||||||
notion_plugin_id = "langgenius/notion_datasource"
|
notion_plugin_id = "langgenius/notion_datasource"
|
||||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||||
jina_plugin_id = "langgenius/jina_datasource"
|
jina_plugin_id = "langgenius/jina_datasource"
|
||||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
if environment == "online":
|
||||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||||
|
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||||
|
else:
|
||||||
|
notion_plugin_unique_identifier = None
|
||||||
|
firecrawl_plugin_unique_identifier = None
|
||||||
|
jina_plugin_unique_identifier = None
|
||||||
oauth_credential_type = CredentialType.OAUTH2
|
oauth_credential_type = CredentialType.OAUTH2
|
||||||
api_key_credential_type = CredentialType.API_KEY
|
api_key_credential_type = CredentialType.API_KEY
|
||||||
|
|
||||||
@ -1521,6 +1531,14 @@ def transform_datasource_credentials():
|
|||||||
auth_count = 0
|
auth_count = 0
|
||||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||||
auth_count += 1
|
auth_count += 1
|
||||||
|
if not firecrawl_tenant_credential.credentials:
|
||||||
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
|
||||||
|
fg="yellow",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
# get credential api key
|
# get credential api key
|
||||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||||
api_key = credentials_json.get("config", {}).get("api_key")
|
api_key = credentials_json.get("config", {}).get("api_key")
|
||||||
@ -1576,6 +1594,14 @@ def transform_datasource_credentials():
|
|||||||
auth_count = 0
|
auth_count = 0
|
||||||
for jina_tenant_credential in jina_tenant_credentials:
|
for jina_tenant_credential in jina_tenant_credentials:
|
||||||
auth_count += 1
|
auth_count += 1
|
||||||
|
if not jina_tenant_credential.credentials:
|
||||||
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
|
||||||
|
fg="yellow",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
# get credential api key
|
# get credential api key
|
||||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||||
api_key = credentials_json.get("config", {}).get("api_key")
|
api_key = credentials_json.get("config", {}).get("api_key")
|
||||||
@ -1583,7 +1609,7 @@ def transform_datasource_credentials():
|
|||||||
"integration_secret": api_key,
|
"integration_secret": api_key,
|
||||||
}
|
}
|
||||||
datasource_provider = DatasourceProvider(
|
datasource_provider = DatasourceProvider(
|
||||||
provider="jina",
|
provider="jinareader",
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
plugin_id=jina_plugin_id,
|
plugin_id=jina_plugin_id,
|
||||||
auth_type=api_key_credential_type.value,
|
auth_type=api_key_credential_type.value,
|
||||||
|
|||||||
@ -113,6 +113,21 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
|||||||
default=10.0,
|
default=10.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CODE_EXECUTION_POOL_MAX_CONNECTIONS: PositiveInt = Field(
|
||||||
|
description="Maximum number of concurrent connections for the code execution HTTP client",
|
||||||
|
default=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
|
||||||
|
description="Maximum number of persistent keep-alive connections for the code execution HTTP client",
|
||||||
|
default=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
|
||||||
|
description="Keep-alive expiry in seconds for idle connections (set to None to disable)",
|
||||||
|
default=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
CODE_MAX_NUMBER: PositiveInt = Field(
|
CODE_MAX_NUMBER: PositiveInt = Field(
|
||||||
description="Maximum allowed numeric value in code execution",
|
description="Maximum allowed numeric value in code execution",
|
||||||
default=9223372036854775807,
|
default=9223372036854775807,
|
||||||
@ -135,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
|||||||
|
|
||||||
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
|
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
|
||||||
description="Maximum allowed length for strings in code execution",
|
description="Maximum allowed length for strings in code execution",
|
||||||
default=80000,
|
default=400_000,
|
||||||
)
|
)
|
||||||
|
|
||||||
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
|
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
|
||||||
@ -153,6 +168,11 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
|||||||
default=1000,
|
default=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
CODE_EXECUTION_SSL_VERIFY: bool = Field(
|
||||||
|
description="Enable or disable SSL verification for code execution requests",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PluginConfig(BaseSettings):
|
class PluginConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
@ -169,6 +189,11 @@ class PluginConfig(BaseSettings):
|
|||||||
default="plugin-api-key",
|
default="plugin-api-key",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||||
|
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||||
|
default=300.0,
|
||||||
|
)
|
||||||
|
|
||||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||||
|
|
||||||
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
||||||
@ -306,12 +331,42 @@ class FileUploadConfig(BaseSettings):
|
|||||||
default=10,
|
default=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
||||||
|
description=(
|
||||||
|
"Comma-separated list of file extensions that are blocked from upload. "
|
||||||
|
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
|
||||||
|
"Empty by default to allow all file types."
|
||||||
|
),
|
||||||
|
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
@computed_field # type: ignore[misc]
|
||||||
|
@property
|
||||||
|
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
|
||||||
|
"""
|
||||||
|
Parse and return the blacklist as a set of lowercase extensions.
|
||||||
|
Returns an empty set if no blacklist is configured.
|
||||||
|
"""
|
||||||
|
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
|
||||||
|
return set()
|
||||||
|
return {
|
||||||
|
ext.strip().lower().strip(".")
|
||||||
|
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
|
||||||
|
if ext.strip()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class HttpConfig(BaseSettings):
|
class HttpConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
HTTP-related configurations for the application
|
HTTP-related configurations for the application
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
COOKIE_DOMAIN: str = Field(
|
||||||
|
description="Explicit cookie domain for console/service cookies when sharing across subdomains",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
API_COMPRESSION_ENABLED: bool = Field(
|
API_COMPRESSION_ENABLED: bool = Field(
|
||||||
description="Enable or disable gzip compression for HTTP responses",
|
description="Enable or disable gzip compression for HTTP responses",
|
||||||
default=False,
|
default=False,
|
||||||
@ -342,11 +397,11 @@ class HttpConfig(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
|
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
|
||||||
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
|
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
|
||||||
)
|
)
|
||||||
|
|
||||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
|
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
|
||||||
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
|
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
|
||||||
)
|
)
|
||||||
|
|
||||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
||||||
@ -404,6 +459,21 @@ class HttpConfig(BaseSettings):
|
|||||||
default=5,
|
default=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
SSRF_POOL_MAX_CONNECTIONS: PositiveInt = Field(
|
||||||
|
description="Maximum number of concurrent connections for the SSRF HTTP client",
|
||||||
|
default=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
|
||||||
|
description="Maximum number of persistent keep-alive connections for the SSRF HTTP client",
|
||||||
|
default=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
SSRF_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
|
||||||
|
description="Keep-alive expiry in seconds for idle SSRF connections (set to None to disable)",
|
||||||
|
default=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
|
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
|
||||||
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
|
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
|
||||||
" when the app is behind a single trusted reverse proxy.",
|
" when the app is behind a single trusted reverse proxy.",
|
||||||
@ -508,7 +578,7 @@ class UpdateConfig(BaseSettings):
|
|||||||
|
|
||||||
class WorkflowVariableTruncationConfig(BaseSettings):
|
class WorkflowVariableTruncationConfig(BaseSettings):
|
||||||
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
|
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
|
||||||
# 100KB
|
# 1000 KiB
|
||||||
1024_000,
|
1024_000,
|
||||||
description="Maximum size for variable to trigger final truncation.",
|
description="Maximum size for variable to trigger final truncation.",
|
||||||
)
|
)
|
||||||
@ -542,16 +612,16 @@ class WorkflowConfig(BaseSettings):
|
|||||||
default=5,
|
default=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
|
|
||||||
description="Maximum allowed depth for nested parallel executions",
|
|
||||||
default=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
||||||
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
|
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
|
||||||
default=200 * 1024,
|
default=200 * 1024,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
|
||||||
|
description="Maximum number of characters allowed in Template Transform node output",
|
||||||
|
default=400_000,
|
||||||
|
)
|
||||||
|
|
||||||
# GraphEngine Worker Pool Configuration
|
# GraphEngine Worker Pool Configuration
|
||||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||||
description="Minimum number of workers per GraphEngine instance",
|
description="Minimum number of workers per GraphEngine instance",
|
||||||
@ -736,7 +806,7 @@ class MailConfig(BaseSettings):
|
|||||||
|
|
||||||
MAIL_TEMPLATING_TIMEOUT: int = Field(
|
MAIL_TEMPLATING_TIMEOUT: int = Field(
|
||||||
description="""
|
description="""
|
||||||
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
|
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
|
||||||
Only available in sandbox mode.""",
|
Only available in sandbox mode.""",
|
||||||
default=3,
|
default=3,
|
||||||
)
|
)
|
||||||
@ -875,6 +945,11 @@ class DataSetConfig(BaseSettings):
|
|||||||
default=True,
|
default=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
|
||||||
|
description="Maximum number of segments for dataset segments API (0 for unlimited)",
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceConfig(BaseSettings):
|
class WorkspaceConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
@ -1067,6 +1142,13 @@ class SwaggerUIConfig(BaseSettings):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||||
|
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
|
||||||
|
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
|
||||||
|
default=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FeatureConfig(
|
class FeatureConfig(
|
||||||
# place the configs in alphabet order
|
# place the configs in alphabet order
|
||||||
AppExecutionConfig,
|
AppExecutionConfig,
|
||||||
@ -1091,6 +1173,7 @@ class FeatureConfig(
|
|||||||
RagEtlConfig,
|
RagEtlConfig,
|
||||||
RepositoryConfig,
|
RepositoryConfig,
|
||||||
SecurityConfig,
|
SecurityConfig,
|
||||||
|
TenantIsolatedTaskQueueConfig,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
UpdateConfig,
|
UpdateConfig,
|
||||||
WorkflowConfig,
|
WorkflowConfig,
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
|
|||||||
from .storage.supabase_storage_config import SupabaseStorageConfig
|
from .storage.supabase_storage_config import SupabaseStorageConfig
|
||||||
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||||
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||||
|
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
|
||||||
from .vdb.analyticdb_config import AnalyticdbConfig
|
from .vdb.analyticdb_config import AnalyticdbConfig
|
||||||
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||||
from .vdb.chroma_config import ChromaConfig
|
from .vdb.chroma_config import ChromaConfig
|
||||||
@ -144,7 +145,7 @@ class DatabaseConfig(BaseSettings):
|
|||||||
default="postgresql",
|
default="postgresql",
|
||||||
)
|
)
|
||||||
|
|
||||||
@computed_field # type: ignore[misc]
|
@computed_field # type: ignore[prop-decorator]
|
||||||
@property
|
@property
|
||||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||||
db_extras = (
|
db_extras = (
|
||||||
@ -197,7 +198,7 @@ class DatabaseConfig(BaseSettings):
|
|||||||
default=os.cpu_count() or 1,
|
default=os.cpu_count() or 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@computed_field # type: ignore[misc]
|
@computed_field # type: ignore[prop-decorator]
|
||||||
@property
|
@property
|
||||||
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
||||||
# Parse DB_EXTRAS for 'options'
|
# Parse DB_EXTRAS for 'options'
|
||||||
@ -330,6 +331,7 @@ class MiddlewareConfig(
|
|||||||
ClickzettaConfig,
|
ClickzettaConfig,
|
||||||
HuaweiCloudConfig,
|
HuaweiCloudConfig,
|
||||||
MilvusConfig,
|
MilvusConfig,
|
||||||
|
AlibabaCloudMySQLConfig,
|
||||||
MyScaleConfig,
|
MyScaleConfig,
|
||||||
OpenSearchConfig,
|
OpenSearchConfig,
|
||||||
OracleConfig,
|
OracleConfig,
|
||||||
|
|||||||
54
api/configs/middleware/vdb/alibabacloud_mysql_config.py
Normal file
54
api/configs/middleware/vdb/alibabacloud_mysql_config.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from pydantic import Field, PositiveInt
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class AlibabaCloudMySQLConfig(BaseSettings):
|
||||||
|
"""
|
||||||
|
Configuration settings for AlibabaCloud MySQL vector database
|
||||||
|
"""
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_HOST: str = Field(
|
||||||
|
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
|
||||||
|
default="localhost",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
|
||||||
|
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
|
||||||
|
default=3306,
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_USER: str = Field(
|
||||||
|
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
|
||||||
|
default="root",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
|
||||||
|
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
|
||||||
|
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
|
||||||
|
default="dify",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
|
||||||
|
description="Maximum number of connections in the connection pool",
|
||||||
|
default=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
|
||||||
|
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
|
||||||
|
default="utf8mb4",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
|
||||||
|
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
|
||||||
|
"(e.g., 'cosine', 'euclidean')",
|
||||||
|
default="cosine",
|
||||||
|
)
|
||||||
|
|
||||||
|
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
|
||||||
|
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
|
||||||
|
default=6,
|
||||||
|
)
|
||||||
@ -40,8 +40,12 @@ class OceanBaseVectorConfig(BaseSettings):
|
|||||||
|
|
||||||
OCEANBASE_FULLTEXT_PARSER: str | None = Field(
|
OCEANBASE_FULLTEXT_PARSER: str | None = Field(
|
||||||
description=(
|
description=(
|
||||||
"Fulltext parser to use for text indexing. Options: 'japanese_ftparser' (Japanese), "
|
"Fulltext parser to use for text indexing. "
|
||||||
"'thai_ftparser' (Thai), 'ik' (Chinese). Default is 'ik'"
|
"Built-in options: 'ngram' (N-gram tokenizer for English/numbers), "
|
||||||
|
"'beng' (Basic English tokenizer), 'space' (Space-based tokenizer), "
|
||||||
|
"'ngram2' (Improved N-gram tokenizer), 'ik' (Chinese tokenizer). "
|
||||||
|
"External plugins (require installation): 'japanese_ftparser' (Japanese tokenizer), "
|
||||||
|
"'thai_ftparser' (Thai tokenizer). Default is 'ik'"
|
||||||
),
|
),
|
||||||
default="ik",
|
default="ik",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,23 +1,24 @@
|
|||||||
from enum import Enum
|
from enum import StrEnum
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field, PositiveInt
|
from pydantic import Field, PositiveInt
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class AuthMethod(StrEnum):
|
||||||
|
"""
|
||||||
|
Authentication method for OpenSearch
|
||||||
|
"""
|
||||||
|
|
||||||
|
BASIC = "basic"
|
||||||
|
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||||
|
|
||||||
|
|
||||||
class OpenSearchConfig(BaseSettings):
|
class OpenSearchConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
Configuration settings for OpenSearch
|
Configuration settings for OpenSearch
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class AuthMethod(Enum):
|
|
||||||
"""
|
|
||||||
Authentication method for OpenSearch
|
|
||||||
"""
|
|
||||||
|
|
||||||
BASIC = "basic"
|
|
||||||
AWS_MANAGED_IAM = "aws_managed_iam"
|
|
||||||
|
|
||||||
OPENSEARCH_HOST: str | None = Field(
|
OPENSEARCH_HOST: str | None = Field(
|
||||||
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
|
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
|
|||||||
default=True,
|
default=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
|
||||||
|
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
|
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
|
||||||
description="Number of objects to be processed in a single batch operation (default is 100)",
|
description="Number of objects to be processed in a single batch operation (default is 100)",
|
||||||
default=100,
|
default=100,
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from libs.collection_utils import convert_to_lower_and_upper_set
|
||||||
|
|
||||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||||
UNKNOWN_VALUE = "[__UNKNOWN__]"
|
UNKNOWN_VALUE = "[__UNKNOWN__]"
|
||||||
@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
|||||||
|
|
||||||
DEFAULT_FILE_NUMBER_LIMITS = 3
|
DEFAULT_FILE_NUMBER_LIMITS = 3
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
|
||||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
|
||||||
|
|
||||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
|
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
|
||||||
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
|
||||||
|
|
||||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
|
||||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
|
||||||
|
|
||||||
|
_doc_extensions: set[str]
|
||||||
_doc_extensions: list[str]
|
|
||||||
if dify_config.ETL_TYPE == "Unstructured":
|
if dify_config.ETL_TYPE == "Unstructured":
|
||||||
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
_doc_extensions = {
|
||||||
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
"txt",
|
||||||
|
"markdown",
|
||||||
|
"md",
|
||||||
|
"mdx",
|
||||||
|
"pdf",
|
||||||
|
"html",
|
||||||
|
"htm",
|
||||||
|
"xlsx",
|
||||||
|
"xls",
|
||||||
|
"vtt",
|
||||||
|
"properties",
|
||||||
|
"doc",
|
||||||
|
"docx",
|
||||||
|
"csv",
|
||||||
|
"eml",
|
||||||
|
"msg",
|
||||||
|
"pptx",
|
||||||
|
"xml",
|
||||||
|
"epub",
|
||||||
|
}
|
||||||
if dify_config.UNSTRUCTURED_API_URL:
|
if dify_config.UNSTRUCTURED_API_URL:
|
||||||
_doc_extensions.append("ppt")
|
_doc_extensions.add("ppt")
|
||||||
else:
|
else:
|
||||||
_doc_extensions = [
|
_doc_extensions = {
|
||||||
"txt",
|
"txt",
|
||||||
"markdown",
|
"markdown",
|
||||||
"md",
|
"md",
|
||||||
@ -37,5 +53,18 @@ else:
|
|||||||
"csv",
|
"csv",
|
||||||
"vtt",
|
"vtt",
|
||||||
"properties",
|
"properties",
|
||||||
]
|
}
|
||||||
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
|
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||||
|
|
||||||
|
# console
|
||||||
|
COOKIE_NAME_ACCESS_TOKEN = "access_token"
|
||||||
|
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
|
||||||
|
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
|
||||||
|
|
||||||
|
# webapp
|
||||||
|
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
|
||||||
|
COOKIE_NAME_PASSPORT = "passport"
|
||||||
|
|
||||||
|
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
|
||||||
|
HEADER_NAME_APP_CODE = "X-App-Code"
|
||||||
|
HEADER_NAME_PASSPORT = "X-App-Passport"
|
||||||
|
|||||||
@ -31,3 +31,9 @@ def supported_language(lang):
|
|||||||
|
|
||||||
error = f"{lang} is not a valid language."
|
error = f"{lang} is not a valid language."
|
||||||
raise ValueError(error)
|
raise ValueError(error)
|
||||||
|
|
||||||
|
|
||||||
|
def get_valid_language(lang: str | None) -> str:
|
||||||
|
if lang and lang in languages:
|
||||||
|
return lang
|
||||||
|
return languages[0]
|
||||||
|
|||||||
@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
|
|||||||
code = 415
|
code = 415
|
||||||
|
|
||||||
|
|
||||||
|
class BlockedFileExtensionError(BaseHTTPException):
|
||||||
|
error_code = "file_extension_blocked"
|
||||||
|
description = "The file extension is blocked for security reasons."
|
||||||
|
code = 400
|
||||||
|
|
||||||
|
|
||||||
class TooManyFilesError(BaseHTTPException):
|
class TooManyFilesError(BaseHTTPException):
|
||||||
error_code = "too_many_files"
|
error_code = "too_many_files"
|
||||||
description = "Only one file is allowed."
|
description = "Only one file is allowed."
|
||||||
|
|||||||
@ -24,7 +24,7 @@ except ImportError:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
|
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
|
||||||
magic = None # type: ignore
|
magic = None # type: ignore[assignment]
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|||||||
@ -1,31 +1,10 @@
|
|||||||
|
from importlib import import_module
|
||||||
|
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from flask_restx import Namespace
|
from flask_restx import Namespace
|
||||||
|
|
||||||
from libs.external_api import ExternalApi
|
from libs.external_api import ExternalApi
|
||||||
|
|
||||||
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
|
|
||||||
from .explore.audio import ChatAudioApi, ChatTextApi
|
|
||||||
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
|
||||||
from .explore.conversation import (
|
|
||||||
ConversationApi,
|
|
||||||
ConversationListApi,
|
|
||||||
ConversationPinApi,
|
|
||||||
ConversationRenameApi,
|
|
||||||
ConversationUnPinApi,
|
|
||||||
)
|
|
||||||
from .explore.message import (
|
|
||||||
MessageFeedbackApi,
|
|
||||||
MessageListApi,
|
|
||||||
MessageMoreLikeThisApi,
|
|
||||||
MessageSuggestedQuestionApi,
|
|
||||||
)
|
|
||||||
from .explore.workflow import (
|
|
||||||
InstalledAppWorkflowRunApi,
|
|
||||||
InstalledAppWorkflowTaskStopApi,
|
|
||||||
)
|
|
||||||
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
|
|
||||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
|
||||||
|
|
||||||
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
||||||
|
|
||||||
api = ExternalApi(
|
api = ExternalApi(
|
||||||
@ -35,23 +14,23 @@ api = ExternalApi(
|
|||||||
description="Console management APIs for app configuration, monitoring, and administration",
|
description="Console management APIs for app configuration, monitoring, and administration",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create namespace
|
|
||||||
console_ns = Namespace("console", description="Console management API operations", path="/")
|
console_ns = Namespace("console", description="Console management API operations", path="/")
|
||||||
|
|
||||||
# File
|
RESOURCE_MODULES = (
|
||||||
api.add_resource(FileApi, "/files/upload")
|
"controllers.console.app.app_import",
|
||||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
"controllers.console.explore.audio",
|
||||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
"controllers.console.explore.completion",
|
||||||
|
"controllers.console.explore.conversation",
|
||||||
|
"controllers.console.explore.message",
|
||||||
|
"controllers.console.explore.workflow",
|
||||||
|
"controllers.console.files",
|
||||||
|
"controllers.console.remote_files",
|
||||||
|
)
|
||||||
|
|
||||||
# Remote files
|
for module_name in RESOURCE_MODULES:
|
||||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
import_module(module_name)
|
||||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
|
||||||
|
|
||||||
# Import App
|
|
||||||
api.add_resource(AppImportApi, "/apps/imports")
|
|
||||||
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
|
|
||||||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
|
||||||
|
|
||||||
|
# Ensure resource modules are imported so route decorators are evaluated.
|
||||||
# Import other controllers
|
# Import other controllers
|
||||||
from . import (
|
from . import (
|
||||||
admin,
|
admin,
|
||||||
@ -150,77 +129,6 @@ from .workspace import (
|
|||||||
workspace,
|
workspace,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Explore Audio
|
|
||||||
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
|
|
||||||
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
|
|
||||||
|
|
||||||
# Explore Completion
|
|
||||||
api.add_resource(
|
|
||||||
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
CompletionStopApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
|
|
||||||
endpoint="installed_app_stop_completion",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
ChatStopApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
|
|
||||||
endpoint="installed_app_stop_chat_completion",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Explore Conversation
|
|
||||||
api.add_resource(
|
|
||||||
ConversationRenameApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
|
|
||||||
endpoint="installed_app_conversation_rename",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
ConversationApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
|
|
||||||
endpoint="installed_app_conversation",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
ConversationPinApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
|
|
||||||
endpoint="installed_app_conversation_pin",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
ConversationUnPinApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
|
|
||||||
endpoint="installed_app_conversation_unpin",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Explore Message
|
|
||||||
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
|
|
||||||
api.add_resource(
|
|
||||||
MessageFeedbackApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
|
|
||||||
endpoint="installed_app_message_feedback",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
MessageMoreLikeThisApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
|
|
||||||
endpoint="installed_app_more_like_this",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
MessageSuggestedQuestionApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
|
|
||||||
endpoint="installed_app_suggested_question",
|
|
||||||
)
|
|
||||||
# Explore Workflow
|
|
||||||
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
|
|
||||||
api.add_resource(
|
|
||||||
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_namespace(console_ns)
|
api.add_namespace(console_ns)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from constants.languages import supported_language
|
|||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.wraps import only_edition_cloud
|
from controllers.console.wraps import only_edition_cloud
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from libs.token import extract_access_token
|
||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models.model import App, InstalledApp, RecommendedApp
|
||||||
|
|
||||||
|
|
||||||
@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]):
|
|||||||
if not dify_config.ADMIN_API_KEY:
|
if not dify_config.ADMIN_API_KEY:
|
||||||
raise Unauthorized("API key is invalid.")
|
raise Unauthorized("API key is invalid.")
|
||||||
|
|
||||||
auth_header = request.headers.get("Authorization")
|
auth_token = extract_access_token(request)
|
||||||
if auth_header is None:
|
if not auth_token:
|
||||||
raise Unauthorized("Authorization header is missing.")
|
raise Unauthorized("Authorization header is missing.")
|
||||||
|
|
||||||
if " " not in auth_header:
|
|
||||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
|
||||||
|
|
||||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
|
||||||
auth_scheme = auth_scheme.lower()
|
|
||||||
|
|
||||||
if auth_scheme != "bearer":
|
|
||||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
|
||||||
|
|
||||||
if auth_token != dify_config.ADMIN_API_KEY:
|
if auth_token != dify_config.ADMIN_API_KEY:
|
||||||
raise Unauthorized("API key is invalid.")
|
raise Unauthorized("API key is invalid.")
|
||||||
|
|
||||||
@ -70,15 +61,17 @@ class InsertExploreAppListApi(Resource):
|
|||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
@admin_required
|
@admin_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("desc", type=str, location="json")
|
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("copyright", type=str, location="json")
|
.add_argument("desc", type=str, location="json")
|
||||||
parser.add_argument("privacy_policy", type=str, location="json")
|
.add_argument("copyright", type=str, location="json")
|
||||||
parser.add_argument("custom_disclaimer", type=str, location="json")
|
.add_argument("privacy_policy", type=str, location="json")
|
||||||
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
.add_argument("custom_disclaimer", type=str, location="json")
|
||||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import flask_restx
|
import flask_restx
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with
|
from flask_restx import Resource, fields, marshal_with
|
||||||
from flask_restx._http import HTTPStatus
|
from flask_restx._http import HTTPStatus
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@ -8,12 +7,12 @@ from werkzeug.exceptions import Forbidden
|
|||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import TimestampField
|
from libs.helper import TimestampField
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from models.model import ApiToken, App
|
from models.model import ApiToken, App
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api, console_ns
|
||||||
from .wraps import account_initialization_required, setup_required
|
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
|
|
||||||
api_key_fields = {
|
api_key_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
@ -57,7 +56,9 @@ class BaseApiKeyListResource(Resource):
|
|||||||
def get(self, resource_id):
|
def get(self, resource_id):
|
||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||||
keys = db.session.scalars(
|
keys = db.session.scalars(
|
||||||
select(ApiToken).where(
|
select(ApiToken).where(
|
||||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||||
@ -66,13 +67,12 @@ class BaseApiKeyListResource(Resource):
|
|||||||
return {"items": keys}
|
return {"items": keys}
|
||||||
|
|
||||||
@marshal_with(api_key_fields)
|
@marshal_with(api_key_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, resource_id):
|
def post(self, resource_id):
|
||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
if not current_user.is_editor:
|
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
current_key_count = (
|
current_key_count = (
|
||||||
db.session.query(ApiToken)
|
db.session.query(ApiToken)
|
||||||
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||||
@ -89,7 +89,7 @@ class BaseApiKeyListResource(Resource):
|
|||||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||||
api_token = ApiToken()
|
api_token = ApiToken()
|
||||||
setattr(api_token, self.resource_id_field, resource_id)
|
setattr(api_token, self.resource_id_field, resource_id)
|
||||||
api_token.tenant_id = current_user.current_tenant_id
|
api_token.tenant_id = current_tenant_id
|
||||||
api_token.token = key
|
api_token.token = key
|
||||||
api_token.type = self.resource_type
|
api_token.type = self.resource_type
|
||||||
db.session.add(api_token)
|
db.session.add(api_token)
|
||||||
@ -108,7 +108,8 @@ class BaseApiKeyResource(Resource):
|
|||||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||||
resource_id = str(resource_id)
|
resource_id = str(resource_id)
|
||||||
api_key_id = str(api_key_id)
|
api_key_id = str(api_key_id)
|
||||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
@ -152,11 +153,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
|||||||
"""Create a new API key for an app"""
|
"""Create a new API key for an app"""
|
||||||
return super().post(resource_id)
|
return super().post(resource_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "app"
|
resource_type = "app"
|
||||||
resource_model = App
|
resource_model = App
|
||||||
resource_id_field = "app_id"
|
resource_id_field = "app_id"
|
||||||
@ -173,11 +169,6 @@ class AppApiKeyResource(BaseApiKeyResource):
|
|||||||
"""Delete an API key for an app"""
|
"""Delete an API key for an app"""
|
||||||
return super().delete(resource_id, api_key_id)
|
return super().delete(resource_id, api_key_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "app"
|
resource_type = "app"
|
||||||
resource_model = App
|
resource_model = App
|
||||||
resource_id_field = "app_id"
|
resource_id_field = "app_id"
|
||||||
@ -202,11 +193,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
|||||||
"""Create a new API key for a dataset"""
|
"""Create a new API key for a dataset"""
|
||||||
return super().post(resource_id)
|
return super().post(resource_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "dataset"
|
resource_type = "dataset"
|
||||||
resource_model = Dataset
|
resource_model = Dataset
|
||||||
resource_id_field = "dataset_id"
|
resource_id_field = "dataset_id"
|
||||||
@ -223,11 +209,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
|||||||
"""Delete an API key for a dataset"""
|
"""Delete an API key for a dataset"""
|
||||||
return super().delete(resource_id, api_key_id)
|
return super().delete(resource_id, api_key_id)
|
||||||
|
|
||||||
def after_request(self, resp):
|
|
||||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
|
||||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
|
||||||
return resp
|
|
||||||
|
|
||||||
resource_type = "dataset"
|
resource_type = "dataset"
|
||||||
resource_model = Dataset
|
resource_model = Dataset
|
||||||
resource_id_field = "dataset_id"
|
resource_id_field = "dataset_id"
|
||||||
|
|||||||
@ -25,11 +25,13 @@ class AdvancedPromptTemplateList(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("app_mode", type=str, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("model_mode", type=str, required=True, location="args")
|
.add_argument("app_mode", type=str, required=True, location="args")
|
||||||
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
|
.add_argument("model_mode", type=str, required=True, location="args")
|
||||||
parser.add_argument("model_name", type=str, required=True, location="args")
|
.add_argument("has_context", type=str, required=False, default="true", location="args")
|
||||||
|
.add_argument("model_name", type=str, required=True, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return AdvancedPromptTemplateService.get_prompt(args)
|
return AdvancedPromptTemplateService.get_prompt(args)
|
||||||
|
|||||||
@ -27,9 +27,11 @@ class AgentLogApi(Resource):
|
|||||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Get agent logs"""
|
"""Get agent logs"""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
.add_argument("message_id", type=uuid_value, required=True, location="args")
|
||||||
|
.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +1,14 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
|
||||||
|
|
||||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
cloud_edition_billing_resource_check,
|
cloud_edition_billing_resource_check,
|
||||||
|
edit_permission_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
@ -17,6 +16,7 @@ from fields.annotation_fields import (
|
|||||||
annotation_fields,
|
annotation_fields,
|
||||||
annotation_hit_history_fields,
|
annotation_hit_history_fields,
|
||||||
)
|
)
|
||||||
|
from libs.helper import uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
|
|
||||||
@ -42,15 +42,15 @@ class AnnotationReplyActionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
.add_argument("score_threshold", required=True, type=float, location="json")
|
||||||
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
|
.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||||
|
.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if action == "enable":
|
if action == "enable":
|
||||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||||
@ -69,10 +69,8 @@ class AppAnnotationSettingDetailApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
||||||
return result, 200
|
return result, 200
|
||||||
@ -98,15 +96,12 @@ class AppAnnotationSettingUpdateApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_id, annotation_setting_id):
|
def post(self, app_id, annotation_setting_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_setting_id = str(annotation_setting_id)
|
annotation_setting_id = str(annotation_setting_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
|
||||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||||
@ -124,10 +119,8 @@ class AnnotationReplyActionStatusApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id, job_id, action):
|
def get(self, app_id, job_id, action):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||||
cache_result = redis_client.get(app_annotation_job_key)
|
cache_result = redis_client.get(app_annotation_job_key)
|
||||||
@ -159,10 +152,8 @@ class AnnotationApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
keyword = request.args.get("keyword", default="", type=str)
|
keyword = request.args.get("keyword", default="", type=str)
|
||||||
@ -185,8 +176,10 @@ class AnnotationApi(Resource):
|
|||||||
api.model(
|
api.model(
|
||||||
"CreateAnnotationRequest",
|
"CreateAnnotationRequest",
|
||||||
{
|
{
|
||||||
"question": fields.String(required=True, description="Question text"),
|
"message_id": fields.String(description="Message ID (optional)"),
|
||||||
"answer": fields.String(required=True, description="Answer text"),
|
"question": fields.String(description="Question text (required when message_id not provided)"),
|
||||||
|
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
||||||
|
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
||||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -198,25 +191,26 @@ class AnnotationApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("question", required=True, type=str, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("answer", required=True, type=str, location="json")
|
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||||
|
.add_argument("question", required=False, type=str, location="json")
|
||||||
|
.add_argument("answer", required=False, type=str, location="json")
|
||||||
|
.add_argument("content", required=False, type=str, location="json")
|
||||||
|
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
|
|
||||||
# Use request.args.getlist to get annotation_ids array directly
|
# Use request.args.getlist to get annotation_ids array directly
|
||||||
@ -249,10 +243,8 @@ class AnnotationExportApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||||
response = {"data": marshal(annotation_list, annotation_fields)}
|
response = {"data": marshal(annotation_list, annotation_fields)}
|
||||||
@ -271,16 +263,16 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@edit_permission_required
|
||||||
@marshal_with(annotation_fields)
|
@marshal_with(annotation_fields)
|
||||||
def post(self, app_id, annotation_id):
|
def post(self, app_id, annotation_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("question", required=True, type=str, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("answer", required=True, type=str, location="json")
|
.add_argument("question", required=True, type=str, location="json")
|
||||||
|
.add_argument("answer", required=True, type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||||
return annotation
|
return annotation
|
||||||
@ -288,10 +280,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_id, annotation_id):
|
def delete(self, app_id, annotation_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
annotation_id = str(annotation_id)
|
annotation_id = str(annotation_id)
|
||||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||||
@ -310,10 +300,8 @@ class AnnotationBatchImportApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
# check file
|
# check file
|
||||||
if "file" not in request.files:
|
if "file" not in request.files:
|
||||||
@ -341,10 +329,8 @@ class AnnotationBatchImportStatusApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
@cloud_edition_billing_resource_check("annotation")
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id, job_id):
|
def get(self, app_id, job_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||||
cache_result = redis_client.get(indexing_cache_key)
|
cache_result = redis_client.get(indexing_cache_key)
|
||||||
@ -376,10 +362,8 @@ class AnnotationHitHistoryListApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_id, annotation_id):
|
def get(self, app_id, annotation_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@ -12,14 +10,16 @@ from controllers.console.app.wraps import get_app_model
|
|||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
cloud_edition_billing_resource_check,
|
cloud_edition_billing_resource_check,
|
||||||
|
edit_permission_required,
|
||||||
enterprise_license_required,
|
enterprise_license_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from core.ops.ops_trace_manager import OpsTraceManager
|
from core.ops.ops_trace_manager import OpsTraceManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account, App
|
from libs.validators import validate_description_length
|
||||||
|
from models import App
|
||||||
from services.app_dsl_service import AppDslService, ImportMode
|
from services.app_dsl_service import AppDslService, ImportMode
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
@ -28,12 +28,6 @@ from services.feature_service import FeatureService
|
|||||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||||
|
|
||||||
|
|
||||||
def _validate_description_length(description):
|
|
||||||
if description and len(description) > 400:
|
|
||||||
raise ValueError("Description cannot exceed 400 characters.")
|
|
||||||
return description
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps")
|
@console_ns.route("/apps")
|
||||||
class AppListApi(Resource):
|
class AppListApi(Resource):
|
||||||
@api.doc("list_apps")
|
@api.doc("list_apps")
|
||||||
@ -61,6 +55,7 @@ class AppListApi(Resource):
|
|||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get app list"""
|
"""Get app list"""
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
def uuid_list(value):
|
def uuid_list(value):
|
||||||
try:
|
try:
|
||||||
@ -68,34 +63,36 @@ class AppListApi(Resource):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
abort(400, message="Invalid UUID format in tag_ids.")
|
abort(400, message="Invalid UUID format in tag_ids.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||||
parser.add_argument(
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
"mode",
|
.add_argument(
|
||||||
type=str,
|
"mode",
|
||||||
choices=[
|
type=str,
|
||||||
"completion",
|
choices=[
|
||||||
"chat",
|
"completion",
|
||||||
"advanced-chat",
|
"chat",
|
||||||
"workflow",
|
"advanced-chat",
|
||||||
"agent-chat",
|
"workflow",
|
||||||
"channel",
|
"agent-chat",
|
||||||
"all",
|
"channel",
|
||||||
],
|
"all",
|
||||||
default="all",
|
],
|
||||||
location="args",
|
default="all",
|
||||||
required=False,
|
location="args",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
.add_argument("name", type=str, location="args", required=False)
|
||||||
|
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||||
|
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||||
)
|
)
|
||||||
parser.add_argument("name", type=str, location="args", required=False)
|
|
||||||
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
|
||||||
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# get app list
|
# get app list
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
|
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
|
||||||
if not app_pagination:
|
if not app_pagination:
|
||||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||||
|
|
||||||
@ -134,30 +131,26 @@ class AppListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
@cloud_edition_billing_resource_check("apps")
|
@cloud_edition_billing_resource_check("apps")
|
||||||
|
@edit_permission_required
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Create app"""
|
"""Create app"""
|
||||||
parser = reqparse.RequestParser()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
parser = (
|
||||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
.add_argument("name", type=str, required=True, location="json")
|
||||||
parser.add_argument("icon_type", type=str, location="json")
|
.add_argument("description", type=validate_description_length, location="json")
|
||||||
parser.add_argument("icon", type=str, location="json")
|
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
.add_argument("icon_type", type=str, location="json")
|
||||||
|
.add_argument("icon", type=str, location="json")
|
||||||
|
.add_argument("icon_background", type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if "mode" not in args or args["mode"] is None:
|
if "mode" not in args or args["mode"] is None:
|
||||||
raise BadRequest("mode is required")
|
raise BadRequest("mode is required")
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
if not isinstance(current_user, Account):
|
app = app_service.create_app(current_tenant_id, args, current_user)
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
if current_user.current_tenant_id is None:
|
|
||||||
raise ValueError("current_user.current_tenant_id cannot be None")
|
|
||||||
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
|
|
||||||
|
|
||||||
return app, 201
|
return app, 201
|
||||||
|
|
||||||
@ -210,21 +203,20 @@ class AppApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
|
@edit_permission_required
|
||||||
@marshal_with(app_detail_fields_with_site)
|
@marshal_with(app_detail_fields_with_site)
|
||||||
def put(self, app_model):
|
def put(self, app_model):
|
||||||
"""Update app"""
|
"""Update app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = (
|
||||||
if not current_user.is_editor:
|
reqparse.RequestParser()
|
||||||
raise Forbidden()
|
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("description", type=validate_description_length, location="json")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
.add_argument("icon_background", type=str, location="json")
|
||||||
parser.add_argument("icon_type", type=str, location="json")
|
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||||
parser.add_argument("icon", type=str, location="json")
|
.add_argument("max_active_requests", type=int, location="json")
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
)
|
||||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
|
||||||
parser.add_argument("max_active_requests", type=int, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
@ -253,12 +245,9 @@ class AppApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_model):
|
def delete(self, app_model):
|
||||||
"""Delete app"""
|
"""Delete app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
app_service.delete_app(app_model)
|
app_service.delete_app(app_model)
|
||||||
|
|
||||||
@ -288,28 +277,29 @@ class AppCopyApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
|
@edit_permission_required
|
||||||
@marshal_with(app_detail_fields_with_site)
|
@marshal_with(app_detail_fields_with_site)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
"""Copy app"""
|
"""Copy app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("name", type=str, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
.add_argument("name", type=str, location="json")
|
||||||
parser.add_argument("icon_type", type=str, location="json")
|
.add_argument("description", type=validate_description_length, location="json")
|
||||||
parser.add_argument("icon", type=str, location="json")
|
.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
.add_argument("icon", type=str, location="json")
|
||||||
|
.add_argument("icon_background", type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||||
account = cast(Account, current_user)
|
|
||||||
result = import_service.import_app(
|
result = import_service.import_app(
|
||||||
account=account,
|
account=current_user,
|
||||||
import_mode=ImportMode.YAML_CONTENT.value,
|
import_mode=ImportMode.YAML_CONTENT,
|
||||||
yaml_content=yaml_content,
|
yaml_content=yaml_content,
|
||||||
name=args.get("name"),
|
name=args.get("name"),
|
||||||
description=args.get("description"),
|
description=args.get("description"),
|
||||||
@ -345,16 +335,15 @@ class AppExportApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
"""Export app"""
|
"""Export app"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# Add include_secret params
|
# Add include_secret params
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("workflow_id", type=str, location="args")
|
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||||
|
.add_argument("workflow_id", type=str, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -376,13 +365,9 @@ class AppNameApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
@ -413,14 +398,13 @@ class AppIconApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = (
|
||||||
if not current_user.is_editor:
|
reqparse.RequestParser()
|
||||||
raise Forbidden()
|
.add_argument("icon", type=str, location="json")
|
||||||
|
.add_argument("icon_background", type=str, location="json")
|
||||||
parser = reqparse.RequestParser()
|
)
|
||||||
parser.add_argument("icon", type=str, location="json")
|
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
@ -446,13 +430,9 @@ class AppSiteStatus(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("enable_site", type=bool, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
@ -480,11 +460,11 @@ class AppApiStatus(Resource):
|
|||||||
@marshal_with(app_detail_fields)
|
@marshal_with(app_detail_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
|
||||||
parser.add_argument("enable_api", type=bool, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app_service = AppService()
|
app_service = AppService()
|
||||||
@ -525,13 +505,14 @@ class AppTraceApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
# add app trace
|
# add app trace
|
||||||
if not current_user.is_editor:
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("enabled", type=bool, required=True, location="json")
|
||||||
parser.add_argument("enabled", type=bool, required=True, location="json")
|
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
OpsTraceManager.update_app_tracing_config(
|
OpsTraceManager.update_app_tracing_config(
|
||||||
|
|||||||
@ -1,54 +1,54 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
|
||||||
|
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
cloud_edition_billing_resource_check,
|
cloud_edition_billing_resource_check,
|
||||||
|
edit_permission_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account
|
|
||||||
from models.model import App
|
from models.model import App
|
||||||
from services.app_dsl_service import AppDslService, ImportStatus
|
from services.app_dsl_service import AppDslService, ImportStatus
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
from .. import console_ns
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/apps/imports")
|
||||||
class AppImportApi(Resource):
|
class AppImportApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_import_fields)
|
@marshal_with(app_import_fields)
|
||||||
@cloud_edition_billing_resource_check("apps")
|
@cloud_edition_billing_resource_check("apps")
|
||||||
|
@edit_permission_required
|
||||||
def post(self):
|
def post(self):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("mode", type=str, required=True, location="json")
|
||||||
parser.add_argument("mode", type=str, required=True, location="json")
|
.add_argument("yaml_content", type=str, location="json")
|
||||||
parser.add_argument("yaml_content", type=str, location="json")
|
.add_argument("yaml_url", type=str, location="json")
|
||||||
parser.add_argument("yaml_url", type=str, location="json")
|
.add_argument("name", type=str, location="json")
|
||||||
parser.add_argument("name", type=str, location="json")
|
.add_argument("description", type=str, location="json")
|
||||||
parser.add_argument("description", type=str, location="json")
|
.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument("icon_type", type=str, location="json")
|
.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument("icon", type=str, location="json")
|
.add_argument("icon_background", type=str, location="json")
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
.add_argument("app_id", type=str, location="json")
|
||||||
parser.add_argument("app_id", type=str, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
# Import app
|
# Import app
|
||||||
account = cast(Account, current_user)
|
account = current_user
|
||||||
result = import_service.import_app(
|
result = import_service.import_app(
|
||||||
account=account,
|
account=account,
|
||||||
import_mode=args["mode"],
|
import_mode=args["mode"],
|
||||||
@ -67,47 +67,47 @@ class AppImportApi(Resource):
|
|||||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||||
# Return appropriate status code based on result
|
# Return appropriate status code based on result
|
||||||
status = result.status
|
status = result.status
|
||||||
if status == ImportStatus.FAILED.value:
|
if status == ImportStatus.FAILED:
|
||||||
return result.model_dump(mode="json"), 400
|
return result.model_dump(mode="json"), 400
|
||||||
elif status == ImportStatus.PENDING.value:
|
elif status == ImportStatus.PENDING:
|
||||||
return result.model_dump(mode="json"), 202
|
return result.model_dump(mode="json"), 202
|
||||||
return result.model_dump(mode="json"), 200
|
return result.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
||||||
class AppImportConfirmApi(Resource):
|
class AppImportConfirmApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_import_fields)
|
@marshal_with(app_import_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, import_id):
|
def post(self, import_id):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
# Confirm import
|
# Confirm import
|
||||||
account = cast(Account, current_user)
|
account = current_user
|
||||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
# Return appropriate status code based on result
|
# Return appropriate status code based on result
|
||||||
if result.status == ImportStatus.FAILED.value:
|
if result.status == ImportStatus.FAILED:
|
||||||
return result.model_dump(mode="json"), 400
|
return result.model_dump(mode="json"), 400
|
||||||
return result.model_dump(mode="json"), 200
|
return result.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
|
||||||
class AppImportCheckDependenciesApi(Resource):
|
class AppImportCheckDependenciesApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_import_check_dependencies_fields)
|
@marshal_with(app_import_check_dependencies_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = AppDslService(session)
|
import_service = AppDslService(session)
|
||||||
result = import_service.check_dependencies(app_model=app_model)
|
result = import_service.check_dependencies(app_model=app_model)
|
||||||
|
|||||||
@ -111,11 +111,13 @@ class ChatMessageTextApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("message_id", type=str, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("text", type=str, location="json")
|
.add_argument("message_id", type=str, location="json")
|
||||||
parser.add_argument("voice", type=str, location="json")
|
.add_argument("text", type=str, location="json")
|
||||||
parser.add_argument("streaming", type=bool, location="json")
|
.add_argument("voice", type=str, location="json")
|
||||||
|
.add_argument("streaming", type=bool, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get("message_id", None)
|
message_id = args.get("message_id", None)
|
||||||
@ -166,8 +168,7 @@ class TextModesApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
|
||||||
parser.add_argument("language", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
response = AudioService.transcript_tts_voices(
|
response = AudioService.transcript_tts_voices(
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import logging
|
|||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
@ -15,7 +15,7 @@ from controllers.console.app.error import (
|
|||||||
ProviderQuotaExceededError,
|
ProviderQuotaExceededError,
|
||||||
)
|
)
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
@ -64,13 +64,15 @@ class CompletionMessageApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("query", type=str, location="json", default="")
|
.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("query", type=str, location="json", default="")
|
||||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
.add_argument("model_config", type=dict, required=True, location="json")
|
||||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
|
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args["response_mode"] != "blocking"
|
streaming = args["response_mode"] != "blocking"
|
||||||
@ -151,22 +153,19 @@ class ChatMessageApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
|
.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("query", type=str, required=True, location="json")
|
||||||
raise Forbidden()
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
|
.add_argument("model_config", type=dict, required=True, location="json")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||||
parser.add_argument("query", type=str, required=True, location="json")
|
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
)
|
||||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
|
||||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
|
||||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
|
||||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args["response_mode"] != "blocking"
|
streaming = args["response_mode"] != "blocking"
|
||||||
|
|||||||
@ -1,16 +1,14 @@
|
|||||||
from datetime import datetime
|
import sqlalchemy as sa
|
||||||
|
from flask import abort
|
||||||
import pytz # pip install pytz
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from sqlalchemy import func, or_
|
from sqlalchemy import func, or_
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.conversation_fields import (
|
from fields.conversation_fields import (
|
||||||
@ -19,10 +17,10 @@ from fields.conversation_fields import (
|
|||||||
conversation_pagination_fields,
|
conversation_pagination_fields,
|
||||||
conversation_with_summary_pagination_fields,
|
conversation_with_summary_pagination_fields,
|
||||||
)
|
)
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account, Conversation, EndUser, Message, MessageAnnotation
|
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
@ -56,21 +54,27 @@ class CompletionConversationApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
@marshal_with(conversation_pagination_fields)
|
@marshal_with(conversation_pagination_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
parser = reqparse.RequestParser()
|
reqparse.RequestParser()
|
||||||
parser.add_argument("keyword", type=str, location="args")
|
.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument(
|
.add_argument(
|
||||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
"annotation_status",
|
||||||
|
type=str,
|
||||||
|
choices=["annotated", "not_annotated", "all"],
|
||||||
|
default="all",
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
|
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||||
)
|
)
|
||||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
|
||||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
query = db.select(Conversation).where(
|
query = sa.select(Conversation).where(
|
||||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -83,25 +87,18 @@ class CompletionConversationApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
account = current_user
|
account = current_user
|
||||||
timezone = pytz.timezone(account.timezone)
|
assert account.timezone is not None
|
||||||
utc_timezone = pytz.utc
|
|
||||||
|
|
||||||
if args["start"]:
|
try:
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
start_datetime = start_datetime.replace(second=0)
|
except ValueError as e:
|
||||||
|
abort(400, description=str(e))
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
|
if start_datetime_utc:
|
||||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||||
|
|
||||||
if args["end"]:
|
if end_datetime_utc:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||||
end_datetime = end_datetime.replace(second=59)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||||
|
|
||||||
# FIXME, the type ignore in this file
|
# FIXME, the type ignore in this file
|
||||||
@ -136,9 +133,8 @@ class CompletionConversationDetailApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
@marshal_with(conversation_message_detail_fields)
|
@marshal_with(conversation_message_detail_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model, conversation_id):
|
def get(self, app_model, conversation_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
return _get_conversation(app_model, conversation_id)
|
return _get_conversation(app_model, conversation_id)
|
||||||
@ -153,14 +149,12 @@ class CompletionConversationDetailApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_model, conversation_id):
|
def delete(self, app_model, conversation_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
ConversationService.delete(app_model, conversation_id, current_user)
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
@ -205,26 +199,32 @@ class ChatConversationApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
@marshal_with(conversation_with_summary_pagination_fields)
|
@marshal_with(conversation_with_summary_pagination_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
parser = reqparse.RequestParser()
|
reqparse.RequestParser()
|
||||||
parser.add_argument("keyword", type=str, location="args")
|
.add_argument("keyword", type=str, location="args")
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
parser.add_argument(
|
.add_argument(
|
||||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
"annotation_status",
|
||||||
)
|
type=str,
|
||||||
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
choices=["annotated", "not_annotated", "all"],
|
||||||
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
default="all",
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
location="args",
|
||||||
parser.add_argument(
|
)
|
||||||
"sort_by",
|
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
||||||
type=str,
|
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
||||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
required=False,
|
.add_argument(
|
||||||
default="-updated_at",
|
"sort_by",
|
||||||
location="args",
|
type=str,
|
||||||
|
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||||
|
required=False,
|
||||||
|
default="-updated_at",
|
||||||
|
location="args",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -236,7 +236,7 @@ class ChatConversationApi(Resource):
|
|||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
|
|
||||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||||
|
|
||||||
if args["keyword"]:
|
if args["keyword"]:
|
||||||
keyword_filter = f"%{args['keyword']}%"
|
keyword_filter = f"%{args['keyword']}%"
|
||||||
@ -259,29 +259,22 @@ class ChatConversationApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
account = current_user
|
account = current_user
|
||||||
timezone = pytz.timezone(account.timezone)
|
assert account.timezone is not None
|
||||||
utc_timezone = pytz.utc
|
|
||||||
|
|
||||||
if args["start"]:
|
try:
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
start_datetime = start_datetime.replace(second=0)
|
except ValueError as e:
|
||||||
|
abort(400, description=str(e))
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
|
if start_datetime_utc:
|
||||||
match args["sort_by"]:
|
match args["sort_by"]:
|
||||||
case "updated_at" | "-updated_at":
|
case "updated_at" | "-updated_at":
|
||||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||||
case "created_at" | "-created_at" | _:
|
case "created_at" | "-created_at" | _:
|
||||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||||
|
|
||||||
if args["end"]:
|
if end_datetime_utc:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||||
end_datetime = end_datetime.replace(second=59)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
match args["sort_by"]:
|
match args["sort_by"]:
|
||||||
case "updated_at" | "-updated_at":
|
case "updated_at" | "-updated_at":
|
||||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||||
@ -308,7 +301,7 @@ class ChatConversationApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||||
|
|
||||||
match args["sort_by"]:
|
match args["sort_by"]:
|
||||||
case "created_at":
|
case "created_at":
|
||||||
@ -340,9 +333,8 @@ class ChatConversationDetailApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
@marshal_with(conversation_detail_fields)
|
@marshal_with(conversation_detail_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model, conversation_id):
|
def get(self, app_model, conversation_id):
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
return _get_conversation(app_model, conversation_id)
|
return _get_conversation(app_model, conversation_id)
|
||||||
@ -357,14 +349,12 @@ class ChatConversationDetailApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_model, conversation_id):
|
def delete(self, app_model, conversation_id):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
conversation_id = str(conversation_id)
|
conversation_id = str(conversation_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
ConversationService.delete(app_model, conversation_id, current_user)
|
ConversationService.delete(app_model, conversation_id, current_user)
|
||||||
except ConversationNotExistsError:
|
except ConversationNotExistsError:
|
||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
@ -373,6 +363,7 @@ class ChatConversationDetailApi(Resource):
|
|||||||
|
|
||||||
|
|
||||||
def _get_conversation(app_model, conversation_id):
|
def _get_conversation(app_model, conversation_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
conversation = (
|
conversation = (
|
||||||
db.session.query(Conversation)
|
db.session.query(Conversation)
|
||||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||||
|
|||||||
@ -29,8 +29,7 @@ class ConversationVariablesApi(Resource):
|
|||||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||||
@marshal_with(paginated_conversation_variable_fields)
|
@marshal_with(paginated_conversation_variable_fields)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
|
||||||
parser.add_argument("conversation_id", type=str, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
stmt = (
|
stmt = (
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields, reqparse
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
@ -17,7 +16,7 @@ from core.helper.code_executor.python3.python3_code_provider import Python3CodeP
|
|||||||
from core.llm_generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import App
|
from models import App
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
@ -43,16 +42,18 @@ class RuleGenerateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
account = current_user
|
|
||||||
try:
|
try:
|
||||||
rules = LLMGenerator.generate_rule_config(
|
rules = LLMGenerator.generate_rule_config(
|
||||||
tenant_id=account.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
no_variable=args["no_variable"],
|
no_variable=args["no_variable"],
|
||||||
@ -93,17 +94,19 @@ class RuleCodeGenerateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||||
|
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
account = current_user
|
|
||||||
try:
|
try:
|
||||||
code_result = LLMGenerator.generate_code(
|
code_result = LLMGenerator.generate_code(
|
||||||
tenant_id=account.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
code_language=args["code_language"],
|
code_language=args["code_language"],
|
||||||
@ -140,15 +143,17 @@ class RuleStructuredOutputGenerateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
account = current_user
|
|
||||||
try:
|
try:
|
||||||
structured_output = LLMGenerator.generate_structured_output(
|
structured_output = LLMGenerator.generate_structured_output(
|
||||||
tenant_id=account.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
)
|
)
|
||||||
@ -189,15 +194,18 @@ class InstructionGenerateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("flow_id", type=str, required=True, default="", location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("node_id", type=str, required=False, default="", location="json")
|
.add_argument("flow_id", type=str, required=True, default="", location="json")
|
||||||
parser.add_argument("current", type=str, required=False, default="", location="json")
|
.add_argument("node_id", type=str, required=False, default="", location="json")
|
||||||
parser.add_argument("language", type=str, required=False, default="javascript", location="json")
|
.add_argument("current", type=str, required=False, default="", location="json")
|
||||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
.add_argument("language", type=str, required=False, default="javascript", location="json")
|
||||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
code_template = (
|
code_template = (
|
||||||
Python3CodeProvider.get_default_code()
|
Python3CodeProvider.get_default_code()
|
||||||
if args["language"] == "python"
|
if args["language"] == "python"
|
||||||
@ -222,21 +230,21 @@ class InstructionGenerateApi(Resource):
|
|||||||
match node_type:
|
match node_type:
|
||||||
case "llm":
|
case "llm":
|
||||||
return LLMGenerator.generate_rule_config(
|
return LLMGenerator.generate_rule_config(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
no_variable=True,
|
no_variable=True,
|
||||||
)
|
)
|
||||||
case "agent":
|
case "agent":
|
||||||
return LLMGenerator.generate_rule_config(
|
return LLMGenerator.generate_rule_config(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
no_variable=True,
|
no_variable=True,
|
||||||
)
|
)
|
||||||
case "code":
|
case "code":
|
||||||
return LLMGenerator.generate_code(
|
return LLMGenerator.generate_code(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
model_config=args["model_config"],
|
model_config=args["model_config"],
|
||||||
code_language=args["language"],
|
code_language=args["language"],
|
||||||
@ -245,7 +253,7 @@ class InstructionGenerateApi(Resource):
|
|||||||
return {"error": f"invalid node type: {node_type}"}
|
return {"error": f"invalid node type: {node_type}"}
|
||||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||||
return LLMGenerator.instruction_modify_legacy(
|
return LLMGenerator.instruction_modify_legacy(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
flow_id=args["flow_id"],
|
flow_id=args["flow_id"],
|
||||||
current=args["current"],
|
current=args["current"],
|
||||||
instruction=args["instruction"],
|
instruction=args["instruction"],
|
||||||
@ -254,7 +262,7 @@ class InstructionGenerateApi(Resource):
|
|||||||
)
|
)
|
||||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||||
return LLMGenerator.instruction_modify_workflow(
|
return LLMGenerator.instruction_modify_workflow(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
flow_id=args["flow_id"],
|
flow_id=args["flow_id"],
|
||||||
node_id=args["node_id"],
|
node_id=args["node_id"],
|
||||||
current=args["current"],
|
current=args["current"],
|
||||||
@ -293,8 +301,7 @@ class InstructionGenerationTemplateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
|
||||||
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
match args["type"]:
|
match args["type"]:
|
||||||
case "prompt":
|
case "prompt":
|
||||||
|
|||||||
@ -1,16 +1,15 @@
|
|||||||
import json
|
import json
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_server_fields
|
from fields.app_fields import app_server_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.model import AppMCPServer
|
from models.model import AppMCPServer
|
||||||
|
|
||||||
|
|
||||||
@ -25,9 +24,9 @@ class AppMCPServerController(Resource):
|
|||||||
@api.doc(description="Get MCP server configuration for an application")
|
@api.doc(description="Get MCP server configuration for an application")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
|
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
|
||||||
@setup_required
|
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@setup_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
@marshal_with(app_server_fields)
|
@marshal_with(app_server_fields)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
@ -48,17 +47,19 @@ class AppMCPServerController(Resource):
|
|||||||
)
|
)
|
||||||
@api.response(201, "MCP server configuration created successfully", app_server_fields)
|
@api.response(201, "MCP server configuration created successfully", app_server_fields)
|
||||||
@api.response(403, "Insufficient permissions")
|
@api.response(403, "Insufficient permissions")
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
@get_app_model
|
||||||
|
@login_required
|
||||||
|
@setup_required
|
||||||
@marshal_with(app_server_fields)
|
@marshal_with(app_server_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise NotFound()
|
parser = (
|
||||||
parser = reqparse.RequestParser()
|
reqparse.RequestParser()
|
||||||
parser.add_argument("description", type=str, required=False, location="json")
|
.add_argument("description", type=str, required=False, location="json")
|
||||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
.add_argument("parameters", type=dict, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
description = args.get("description")
|
description = args.get("description")
|
||||||
@ -71,7 +72,7 @@ class AppMCPServerController(Resource):
|
|||||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||||
status=AppMCPServerStatus.ACTIVE,
|
status=AppMCPServerStatus.ACTIVE,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
server_code=AppMCPServer.generate_server_code(16),
|
server_code=AppMCPServer.generate_server_code(16),
|
||||||
)
|
)
|
||||||
db.session.add(server)
|
db.session.add(server)
|
||||||
@ -95,19 +96,20 @@ class AppMCPServerController(Resource):
|
|||||||
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
|
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
|
||||||
@api.response(403, "Insufficient permissions")
|
@api.response(403, "Insufficient permissions")
|
||||||
@api.response(404, "Server not found")
|
@api.response(404, "Server not found")
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
@get_app_model
|
@get_app_model
|
||||||
|
@login_required
|
||||||
|
@setup_required
|
||||||
|
@account_initialization_required
|
||||||
@marshal_with(app_server_fields)
|
@marshal_with(app_server_fields)
|
||||||
|
@edit_permission_required
|
||||||
def put(self, app_model):
|
def put(self, app_model):
|
||||||
if not current_user.is_editor:
|
parser = (
|
||||||
raise NotFound()
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("id", type=str, required=True, location="json")
|
||||||
parser.add_argument("id", type=str, required=True, location="json")
|
.add_argument("description", type=str, required=False, location="json")
|
||||||
parser.add_argument("description", type=str, required=False, location="json")
|
.add_argument("parameters", type=dict, required=True, location="json")
|
||||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
.add_argument("status", type=str, required=False, location="json")
|
||||||
parser.add_argument("status", type=str, required=False, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
||||||
if not server:
|
if not server:
|
||||||
@ -142,13 +144,13 @@ class AppMCPServerRefreshController(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(app_server_fields)
|
@marshal_with(app_server_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, server_id):
|
def get(self, server_id):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise NotFound()
|
|
||||||
server = (
|
server = (
|
||||||
db.session.query(AppMCPServer)
|
db.session.query(AppMCPServer)
|
||||||
.where(AppMCPServer.id == server_id)
|
.where(AppMCPServer.id == server_id)
|
||||||
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
|
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not server:
|
if not server:
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import logging
|
|||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from sqlalchemy import exists, select
|
from sqlalchemy import exists, select
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
@ -16,20 +16,18 @@ from controllers.console.app.wraps import get_app_model
|
|||||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
cloud_edition_billing_resource_check,
|
edit_permission_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
from fields.conversation_fields import message_detail_fields
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
|
||||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||||
from services.annotation_service import AppAnnotationService
|
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||||
from services.message_service import MessageService
|
from services.message_service import MessageService
|
||||||
@ -56,19 +54,19 @@ class ChatMessageListApi(Resource):
|
|||||||
)
|
)
|
||||||
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
|
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
|
||||||
@api.response(404, "Conversation not found")
|
@api.response(404, "Conversation not found")
|
||||||
@setup_required
|
|
||||||
@login_required
|
@login_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@setup_required
|
||||||
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
|
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("first_id", type=uuid_value, location="args")
|
||||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
)
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
conversation = (
|
conversation = (
|
||||||
@ -154,12 +152,13 @@ class MessageFeedbackApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
if current_user is None:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||||
|
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = str(args["message_id"])
|
message_id = str(args["message_id"])
|
||||||
@ -193,47 +192,6 @@ class MessageFeedbackApi(Resource):
|
|||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
|
||||||
class MessageAnnotationApi(Resource):
|
|
||||||
@api.doc("create_message_annotation")
|
|
||||||
@api.doc(description="Create message annotation")
|
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
|
||||||
@api.expect(
|
|
||||||
api.model(
|
|
||||||
"MessageAnnotationRequest",
|
|
||||||
{
|
|
||||||
"message_id": fields.String(description="Message ID"),
|
|
||||||
"question": fields.String(required=True, description="Question text"),
|
|
||||||
"answer": fields.String(required=True, description="Answer text"),
|
|
||||||
"annotation_reply": fields.Raw(description="Annotation reply"),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@api.response(200, "Annotation created successfully", annotation_fields)
|
|
||||||
@api.response(403, "Insufficient permissions")
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
@cloud_edition_billing_resource_check("annotation")
|
|
||||||
@get_app_model
|
|
||||||
@marshal_with(annotation_fields)
|
|
||||||
def post(self, app_model):
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
|
|
||||||
parser.add_argument("question", required=True, type=str, location="json")
|
|
||||||
parser.add_argument("answer", required=True, type=str, location="json")
|
|
||||||
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
|
||||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
|
||||||
|
|
||||||
return annotation
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
|
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
|
||||||
class MessageAnnotationCountApi(Resource):
|
class MessageAnnotationCountApi(Resource):
|
||||||
@api.doc("get_annotation_count")
|
@api.doc("get_annotation_count")
|
||||||
@ -270,6 +228,7 @@ class MessageSuggestedQuestionApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
def get(self, app_model, message_id):
|
def get(self, app_model, message_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -304,12 +263,12 @@ class MessageApi(Resource):
|
|||||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||||
@api.response(200, "Message retrieved successfully", message_detail_fields)
|
@api.response(200, "Message retrieved successfully", message_detail_fields)
|
||||||
@api.response(404, "Message not found")
|
@api.response(404, "Message not found")
|
||||||
|
@get_app_model
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model
|
|
||||||
@marshal_with(message_detail_fields)
|
@marshal_with(message_detail_fields)
|
||||||
def get(self, app_model, message_id):
|
def get(self, app_model, message_id: str):
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import json
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -14,8 +13,8 @@ from core.tools.tool_manager import ToolManager
|
|||||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||||
from events.app_event import app_model_config_was_updated
|
from events.app_event import app_model_config_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import login_required
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models.account import Account
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.model import AppMode, AppModelConfig
|
from models.model import AppMode, AppModelConfig
|
||||||
from services.app_model_config_service import AppModelConfigService
|
from services.app_model_config_service import AppModelConfigService
|
||||||
|
|
||||||
@ -53,16 +52,14 @@ class ModelConfigResource(Resource):
|
|||||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
"""Modify app model config"""
|
"""Modify app model config"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if not current_user.has_edit_permission:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
|
||||||
# validate config
|
# validate config
|
||||||
model_configuration = AppModelConfigService.validate_configuration(
|
model_configuration = AppModelConfigService.validate_configuration(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
config=cast(dict, request.json),
|
config=cast(dict, request.json),
|
||||||
app_mode=AppMode.value_of(app_model.mode),
|
app_mode=AppMode.value_of(app_model.mode),
|
||||||
)
|
)
|
||||||
@ -90,16 +87,16 @@ class ModelConfigResource(Resource):
|
|||||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
agent_tool_entity = AgentToolEntity(**tool)
|
agent_tool_entity = AgentToolEntity.model_validate(tool)
|
||||||
# get tool
|
# get tool
|
||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
agent_tool=agent_tool_entity,
|
agent_tool=agent_tool_entity,
|
||||||
)
|
)
|
||||||
manager = ToolParameterConfigurationManager(
|
manager = ToolParameterConfigurationManager(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=agent_tool_entity.provider_id,
|
provider_name=agent_tool_entity.provider_id,
|
||||||
provider_type=agent_tool_entity.provider_type,
|
provider_type=agent_tool_entity.provider_type,
|
||||||
@ -124,7 +121,7 @@ class ModelConfigResource(Resource):
|
|||||||
# encrypt agent tool parameters if it's secret-input
|
# encrypt agent tool parameters if it's secret-input
|
||||||
agent_mode = new_app_model_config.agent_mode_dict
|
agent_mode = new_app_model_config.agent_mode_dict
|
||||||
for tool in agent_mode.get("tools") or []:
|
for tool in agent_mode.get("tools") or []:
|
||||||
agent_tool_entity = AgentToolEntity(**tool)
|
agent_tool_entity = AgentToolEntity.model_validate(tool)
|
||||||
|
|
||||||
# get tool
|
# get tool
|
||||||
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
|
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
|
||||||
@ -133,7 +130,7 @@ class ModelConfigResource(Resource):
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
app_id=app_model.id,
|
app_id=app_model.id,
|
||||||
agent_tool=agent_tool_entity,
|
agent_tool=agent_tool_entity,
|
||||||
)
|
)
|
||||||
@ -141,7 +138,7 @@ class ModelConfigResource(Resource):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
manager = ToolParameterConfigurationManager(
|
manager = ToolParameterConfigurationManager(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
tool_runtime=tool_runtime,
|
tool_runtime=tool_runtime,
|
||||||
provider_name=agent_tool_entity.provider_id,
|
provider_name=agent_tool_entity.provider_id,
|
||||||
provider_type=agent_tool_entity.provider_type,
|
provider_type=agent_tool_entity.provider_type,
|
||||||
@ -172,6 +169,8 @@ class ModelConfigResource(Resource):
|
|||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
app_model.app_model_config_id = new_app_model_config.id
|
app_model.app_model_config_id = new_app_model_config.id
|
||||||
|
app_model.updated_by = current_user.id
|
||||||
|
app_model.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||||
|
|||||||
@ -30,8 +30,7 @@ class TraceAppConfigApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -63,9 +62,11 @@ class TraceAppConfigApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, app_id):
|
def post(self, app_id):
|
||||||
"""Create a new trace app configuration"""
|
"""Create a new trace app configuration"""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||||
|
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -99,9 +100,11 @@ class TraceAppConfigApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, app_id):
|
def patch(self, app_id):
|
||||||
"""Update an existing trace app configuration"""
|
"""Update an existing trace app configuration"""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||||
|
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -129,8 +132,7 @@ class TraceAppConfigApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, app_id):
|
def delete(self, app_id):
|
||||||
"""Delete an existing trace app configuration"""
|
"""Delete an existing trace app configuration"""
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
@ -9,30 +8,36 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_site_fields
|
from fields.app_fields import app_site_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account, Site
|
from models import Site
|
||||||
|
|
||||||
|
|
||||||
def parse_app_site_args():
|
def parse_app_site_args():
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("title", type=str, required=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("icon_type", type=str, required=False, location="json")
|
.add_argument("title", type=str, required=False, location="json")
|
||||||
parser.add_argument("icon", type=str, required=False, location="json")
|
.add_argument("icon_type", type=str, required=False, location="json")
|
||||||
parser.add_argument("icon_background", type=str, required=False, location="json")
|
.add_argument("icon", type=str, required=False, location="json")
|
||||||
parser.add_argument("description", type=str, required=False, location="json")
|
.add_argument("icon_background", type=str, required=False, location="json")
|
||||||
parser.add_argument("default_language", type=supported_language, required=False, location="json")
|
.add_argument("description", type=str, required=False, location="json")
|
||||||
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
|
.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||||
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||||
parser.add_argument("customize_domain", type=str, required=False, location="json")
|
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||||
parser.add_argument("copyright", type=str, required=False, location="json")
|
.add_argument("customize_domain", type=str, required=False, location="json")
|
||||||
parser.add_argument("privacy_policy", type=str, required=False, location="json")
|
.add_argument("copyright", type=str, required=False, location="json")
|
||||||
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||||
parser.add_argument(
|
.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||||
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
|
.add_argument(
|
||||||
|
"customize_token_strategy",
|
||||||
|
type=str,
|
||||||
|
choices=["must", "allow", "not_allow"],
|
||||||
|
required=False,
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||||
|
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||||
|
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("prompt_public", type=bool, required=False, location="json")
|
|
||||||
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
|
||||||
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@ -76,9 +81,10 @@ class AppSite(Resource):
|
|||||||
@marshal_with(app_site_fields)
|
@marshal_with(app_site_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
args = parse_app_site_args()
|
args = parse_app_site_args()
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be editor, admin, or owner
|
# The role of the current user in the ta table must be editor, admin, or owner
|
||||||
if not current_user.is_editor:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||||
@ -107,8 +113,6 @@ class AppSite(Resource):
|
|||||||
if value is not None:
|
if value is not None:
|
||||||
setattr(site, attr_name, value)
|
setattr(site, attr_name, value)
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
site.updated_by = current_user.id
|
site.updated_by = current_user.id
|
||||||
site.updated_at = naive_utc_now()
|
site.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -131,6 +135,8 @@ class AppSiteAccessTokenReset(Resource):
|
|||||||
@marshal_with(app_site_fields)
|
@marshal_with(app_site_fields)
|
||||||
def post(self, app_model):
|
def post(self, app_model):
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
@ -140,8 +146,6 @@ class AppSiteAccessTokenReset(Resource):
|
|||||||
raise NotFound
|
raise NotFound
|
||||||
|
|
||||||
site.code = Site.generate_code(16)
|
site.code = Site.generate_code(16)
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
site.updated_by = current_user.id
|
site.updated_by = current_user.id
|
||||||
site.updated_at = naive_utc_now()
|
site.updated_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|||||||
@ -1,10 +1,7 @@
|
|||||||
from datetime import datetime
|
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
import pytz
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import jsonify
|
from flask import abort, jsonify
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, reqparse
|
from flask_restx import Resource, fields, reqparse
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
@ -12,8 +9,9 @@ from controllers.console.app.wraps import get_app_model
|
|||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from libs.datetime_utils import parse_time_range
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import AppMode, Message
|
from models import AppMode, Message
|
||||||
|
|
||||||
|
|
||||||
@ -37,11 +35,13 @@ class DailyMessageStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -52,28 +52,19 @@ FROM
|
|||||||
WHERE
|
WHERE
|
||||||
app_id = :app_id
|
app_id = :app_id
|
||||||
AND invoke_from != :invoke_from"""
|
AND invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
if args["start"]:
|
abort(400, description=str(e))
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
|
||||||
start_datetime = start_datetime.replace(second=0)
|
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
|
if start_datetime_utc:
|
||||||
sql_query += " AND created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if end_datetime_utc:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
@ -109,15 +100,20 @@ class DailyConversationStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
assert account.timezone is not None
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
|
abort(400, description=str(e))
|
||||||
|
|
||||||
stmt = (
|
stmt = (
|
||||||
sa.select(
|
sa.select(
|
||||||
@ -127,21 +123,13 @@ class DailyConversationStatistic(Resource):
|
|||||||
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
|
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
|
||||||
)
|
)
|
||||||
.select_from(Message)
|
.select_from(Message)
|
||||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
|
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
|
||||||
)
|
)
|
||||||
|
|
||||||
if args["start"]:
|
if start_datetime_utc:
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
|
||||||
start_datetime = start_datetime.replace(second=0)
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
stmt = stmt.where(Message.created_at >= start_datetime_utc)
|
stmt = stmt.where(Message.created_at >= start_datetime_utc)
|
||||||
|
|
||||||
if args["end"]:
|
if end_datetime_utc:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
stmt = stmt.where(Message.created_at < end_datetime_utc)
|
stmt = stmt.where(Message.created_at < end_datetime_utc)
|
||||||
|
|
||||||
stmt = stmt.group_by("date").order_by("date")
|
stmt = stmt.group_by("date").order_by("date")
|
||||||
@ -175,11 +163,13 @@ class DailyTerminalsStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -190,28 +180,19 @@ FROM
|
|||||||
WHERE
|
WHERE
|
||||||
app_id = :app_id
|
app_id = :app_id
|
||||||
AND invoke_from != :invoke_from"""
|
AND invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
if args["start"]:
|
abort(400, description=str(e))
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
|
||||||
start_datetime = start_datetime.replace(second=0)
|
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
|
if start_datetime_utc:
|
||||||
sql_query += " AND created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if end_datetime_utc:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
@ -247,11 +228,13 @@ class DailyTokenCostStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -263,28 +246,19 @@ FROM
|
|||||||
WHERE
|
WHERE
|
||||||
app_id = :app_id
|
app_id = :app_id
|
||||||
AND invoke_from != :invoke_from"""
|
AND invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
if args["start"]:
|
abort(400, description=str(e))
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
|
||||||
start_datetime = start_datetime.replace(second=0)
|
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
|
if start_datetime_utc:
|
||||||
sql_query += " AND created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if end_datetime_utc:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
@ -322,11 +296,13 @@ class AverageSessionInteractionStatistic(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -345,28 +321,19 @@ FROM
|
|||||||
WHERE
|
WHERE
|
||||||
c.app_id = :app_id
|
c.app_id = :app_id
|
||||||
AND m.invoke_from != :invoke_from"""
|
AND m.invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
if args["start"]:
|
abort(400, description=str(e))
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
|
||||||
start_datetime = start_datetime.replace(second=0)
|
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
|
if start_datetime_utc:
|
||||||
sql_query += " AND c.created_at >= :start"
|
sql_query += " AND c.created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if end_datetime_utc:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query += " AND c.created_at < :end"
|
sql_query += " AND c.created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
@ -413,11 +380,13 @@ class UserSatisfactionRateStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -432,28 +401,19 @@ LEFT JOIN
|
|||||||
WHERE
|
WHERE
|
||||||
m.app_id = :app_id
|
m.app_id = :app_id
|
||||||
AND m.invoke_from != :invoke_from"""
|
AND m.invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
if args["start"]:
|
abort(400, description=str(e))
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
|
||||||
start_datetime = start_datetime.replace(second=0)
|
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
|
if start_datetime_utc:
|
||||||
sql_query += " AND m.created_at >= :start"
|
sql_query += " AND m.created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if end_datetime_utc:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query += " AND m.created_at < :end"
|
sql_query += " AND m.created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
@ -494,11 +454,13 @@ class AverageResponseTimeStatistic(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=AppMode.COMPLETION)
|
@get_app_model(mode=AppMode.COMPLETION)
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -509,28 +471,19 @@ FROM
|
|||||||
WHERE
|
WHERE
|
||||||
app_id = :app_id
|
app_id = :app_id
|
||||||
AND invoke_from != :invoke_from"""
|
AND invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
if args["start"]:
|
abort(400, description=str(e))
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
|
||||||
start_datetime = start_datetime.replace(second=0)
|
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
|
if start_datetime_utc:
|
||||||
sql_query += " AND created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if end_datetime_utc:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
@ -566,11 +519,13 @@ class TokensPerSecondStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
sql_query = """SELECT
|
||||||
@ -584,28 +539,19 @@ FROM
|
|||||||
WHERE
|
WHERE
|
||||||
app_id = :app_id
|
app_id = :app_id
|
||||||
AND invoke_from != :invoke_from"""
|
AND invoke_from != :invoke_from"""
|
||||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||||
|
assert account.timezone is not None
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
if args["start"]:
|
abort(400, description=str(e))
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
|
||||||
start_datetime = start_datetime.replace(second=0)
|
|
||||||
|
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
|
if start_datetime_utc:
|
||||||
sql_query += " AND created_at >= :start"
|
sql_query += " AND created_at >= :start"
|
||||||
arg_dict["start"] = start_datetime_utc
|
arg_dict["start"] = start_datetime_utc
|
||||||
|
|
||||||
if args["end"]:
|
if end_datetime_utc:
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
sql_query += " AND created_at < :end"
|
||||||
arg_dict["end"] = end_datetime_utc
|
arg_dict["end"] = end_datetime_utc
|
||||||
|
|
||||||
|
|||||||
@ -9,11 +9,10 @@ from sqlalchemy.orm import Session
|
|||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
@ -26,10 +25,10 @@ from factories import file_factory, variable_factory
|
|||||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||||
from libs import helper
|
from libs import helper
|
||||||
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import App
|
from models import App
|
||||||
from models.account import Account
|
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
@ -70,15 +69,11 @@ class DraftWorkflowApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@marshal_with(workflow_fields)
|
@marshal_with(workflow_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Get draft workflow
|
Get draft workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# fetch draft workflow by app_model
|
# fetch draft workflow by app_model
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||||
@ -107,27 +102,38 @@ class DraftWorkflowApi(Resource):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@api.response(200, "Draft workflow synced successfully", workflow_fields)
|
@api.response(
|
||||||
|
200,
|
||||||
|
"Draft workflow synced successfully",
|
||||||
|
api.model(
|
||||||
|
"SyncDraftWorkflowResponse",
|
||||||
|
{
|
||||||
|
"result": fields.String,
|
||||||
|
"hash": fields.String,
|
||||||
|
"updated_at": fields.String,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
@api.response(400, "Invalid workflow configuration")
|
@api.response(400, "Invalid workflow configuration")
|
||||||
@api.response(403, "Permission denied")
|
@api.response(403, "Permission denied")
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Sync draft workflow
|
Sync draft workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
current_user, _ = current_account_with_tenant()
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
content_type = request.headers.get("Content-Type", "")
|
content_type = request.headers.get("Content-Type", "")
|
||||||
|
|
||||||
if "application/json" in content_type:
|
if "application/json" in content_type:
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("hash", type=str, required=False, location="json")
|
.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("environment_variables", type=list, required=True, location="json")
|
.add_argument("hash", type=str, required=False, location="json")
|
||||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
.add_argument("environment_variables", type=list, required=True, location="json")
|
||||||
|
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
elif "text/plain" in content_type:
|
elif "text/plain" in content_type:
|
||||||
try:
|
try:
|
||||||
@ -149,10 +155,6 @@ class DraftWorkflowApi(Resource):
|
|||||||
return {"message": "Invalid JSON data"}, 400
|
return {"message": "Invalid JSON data"}, 400
|
||||||
else:
|
else:
|
||||||
abort(415)
|
abort(415)
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -206,24 +208,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Run draft workflow
|
Run draft workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
current_user, _ = current_account_with_tenant()
|
||||||
assert isinstance(current_user, Account)
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
|
.add_argument("inputs", type=dict, location="json")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("query", type=str, required=True, location="json", default="")
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
.add_argument("files", type=list, location="json")
|
||||||
parser.add_argument("query", type=str, required=True, location="json", default="")
|
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument("files", type=list, location="json")
|
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
)
|
||||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -271,18 +270,13 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, node_id: str):
|
def post(self, app_model: App, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow iteration node
|
Run draft workflow iteration node
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -323,18 +317,13 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, node_id: str):
|
def post(self, app_model: App, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow iteration node
|
Run draft workflow iteration node
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
raise Forbidden()
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -375,19 +364,13 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, node_id: str):
|
def post(self, app_model: App, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow loop node
|
Run draft workflow loop node
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -428,19 +411,13 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, node_id: str):
|
def post(self, app_model: App, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow loop node
|
Run draft workflow loop node
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -480,20 +457,17 @@ class DraftWorkflowRunApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Run draft workflow
|
Run draft workflow
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
raise Forbidden()
|
)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
external_trace_id = get_external_trace_id(request)
|
external_trace_id = get_external_trace_id(request)
|
||||||
@ -526,17 +500,11 @@ class WorkflowTaskStopApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, task_id: str):
|
def post(self, app_model: App, task_id: str):
|
||||||
"""
|
"""
|
||||||
Stop workflow task
|
Stop workflow task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# Stop using both mechanisms for backward compatibility
|
# Stop using both mechanisms for backward compatibility
|
||||||
# Legacy stop flag mechanism (without user check)
|
# Legacy stop flag mechanism (without user check)
|
||||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||||
@ -568,21 +536,18 @@ class DraftWorkflowNodeRunApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@marshal_with(workflow_run_node_execution_fields)
|
@marshal_with(workflow_run_node_execution_fields)
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App, node_id: str):
|
def post(self, app_model: App, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow node
|
Run draft workflow node
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("query", type=str, required=False, location="json", default="")
|
||||||
raise Forbidden()
|
.add_argument("files", type=list, location="json", default=[])
|
||||||
|
)
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
parser.add_argument("query", type=str, required=False, location="json", default="")
|
|
||||||
parser.add_argument("files", type=list, location="json", default=[])
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user_inputs = args.get("inputs")
|
user_inputs = args.get("inputs")
|
||||||
@ -622,17 +587,11 @@ class PublishedWorkflowApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@marshal_with(workflow_fields)
|
@marshal_with(workflow_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Get published workflow
|
Get published workflow
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# fetch published workflow by app_model
|
# fetch published workflow by app_model
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
||||||
@ -644,19 +603,17 @@ class PublishedWorkflowApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Publish workflow
|
Publish workflow
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
reqparse.RequestParser()
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||||
raise Forbidden()
|
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||||
|
)
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
|
|
||||||
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate name and comment length
|
# Validate name and comment length
|
||||||
@ -675,8 +632,12 @@ class PublishedWorkflowApi(Resource):
|
|||||||
marked_comment=args.marked_comment or "",
|
marked_comment=args.marked_comment or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
app_model.workflow_id = workflow.id
|
# Update app_model within the same session to ensure atomicity
|
||||||
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
|
app_model_in_session = session.get(App, app_model.id)
|
||||||
|
if app_model_in_session:
|
||||||
|
app_model_in_session.workflow_id = workflow.id
|
||||||
|
app_model_in_session.updated_by = current_user.id
|
||||||
|
app_model_in_session.updated_at = naive_utc_now()
|
||||||
|
|
||||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||||
|
|
||||||
@ -698,17 +659,11 @@ class DefaultBlockConfigsApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Get default block config
|
Get default block config
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# Get default block configs
|
# Get default block configs
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
return workflow_service.get_default_block_configs()
|
return workflow_service.get_default_block_configs()
|
||||||
@ -725,18 +680,12 @@ class DefaultBlockConfigApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App, block_type: str):
|
def get(self, app_model: App, block_type: str):
|
||||||
"""
|
"""
|
||||||
Get default block config
|
Get default block config
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("q", type=str, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
q = args.get("q")
|
q = args.get("q")
|
||||||
@ -765,24 +714,23 @@ class ConvertToWorkflowApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
|
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
|
||||||
|
@edit_permission_required
|
||||||
def post(self, app_model: App):
|
def post(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Convert basic mode of chatbot app to workflow mode
|
Convert basic mode of chatbot app to workflow mode
|
||||||
Convert expert mode of chatbot app to workflow mode
|
Convert expert mode of chatbot app to workflow mode
|
||||||
Convert Completion App to Workflow App
|
Convert Completion App to Workflow App
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if request.data:
|
if request.data:
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
else:
|
else:
|
||||||
args = {}
|
args = {}
|
||||||
@ -797,24 +745,6 @@ class ConvertToWorkflowApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/config")
|
|
||||||
class WorkflowConfigApi(Resource):
|
|
||||||
"""Resource for workflow configuration."""
|
|
||||||
|
|
||||||
@api.doc("get_workflow_config")
|
|
||||||
@api.doc(description="Get workflow configuration")
|
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
|
||||||
@api.response(200, "Workflow configuration retrieved successfully")
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
|
||||||
def get(self, app_model: App):
|
|
||||||
return {
|
|
||||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||||
class PublishedAllWorkflowApi(Resource):
|
class PublishedAllWorkflowApi(Resource):
|
||||||
@api.doc("get_all_published_workflows")
|
@api.doc("get_all_published_workflows")
|
||||||
@ -826,21 +756,20 @@ class PublishedAllWorkflowApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@marshal_with(workflow_pagination_fields)
|
@marshal_with(workflow_pagination_fields)
|
||||||
|
@edit_permission_required
|
||||||
def get(self, app_model: App):
|
def get(self, app_model: App):
|
||||||
"""
|
"""
|
||||||
Get published workflows
|
Get published workflows
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
parser = (
|
||||||
raise Forbidden()
|
reqparse.RequestParser()
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||||
raise Forbidden()
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
.add_argument("user_id", type=str, required=False, location="args")
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
)
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
parser.add_argument("user_id", type=str, required=False, location="args")
|
|
||||||
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
page = int(args.get("page", 1))
|
page = int(args.get("page", 1))
|
||||||
limit = int(args.get("limit", 10))
|
limit = int(args.get("limit", 10))
|
||||||
@ -893,19 +822,17 @@ class WorkflowByIdApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
@marshal_with(workflow_fields)
|
@marshal_with(workflow_fields)
|
||||||
|
@edit_permission_required
|
||||||
def patch(self, app_model: App, workflow_id: str):
|
def patch(self, app_model: App, workflow_id: str):
|
||||||
"""
|
"""
|
||||||
Update workflow attributes
|
Update workflow attributes
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
# Check permission
|
reqparse.RequestParser()
|
||||||
if not current_user.has_edit_permission:
|
.add_argument("marked_name", type=str, required=False, location="json")
|
||||||
raise Forbidden()
|
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
|
||||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate name and comment length
|
# Validate name and comment length
|
||||||
@ -948,16 +875,11 @@ class WorkflowByIdApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@edit_permission_required
|
||||||
def delete(self, app_model: App, workflow_id: str):
|
def delete(self, app_model: App, workflow_id: str):
|
||||||
"""
|
"""
|
||||||
Delete workflow
|
Delete workflow
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise Forbidden()
|
|
||||||
# Check permission
|
|
||||||
if not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
workflow_service = WorkflowService()
|
||||||
|
|
||||||
# Create a session and manage the transaction
|
# Create a session and manage the transaction
|
||||||
|
|||||||
@ -42,33 +42,35 @@ class WorkflowAppLogApi(Resource):
|
|||||||
"""
|
"""
|
||||||
Get workflow app logs
|
Get workflow app logs
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("keyword", type=str, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("keyword", type=str, location="args")
|
||||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
.add_argument(
|
||||||
|
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||||
|
)
|
||||||
|
.add_argument(
|
||||||
|
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||||
|
)
|
||||||
|
.add_argument(
|
||||||
|
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
||||||
|
)
|
||||||
|
.add_argument(
|
||||||
|
"created_by_end_user_session_id",
|
||||||
|
type=str,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
.add_argument(
|
||||||
|
"created_by_account",
|
||||||
|
type=str,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"created_by_end_user_session_id",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"created_by_account",
|
|
||||||
type=str,
|
|
||||||
location="args",
|
|
||||||
required=False,
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
|
||||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||||
|
|||||||
@ -22,8 +22,7 @@ from extensions.ext_database import db
|
|||||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||||
from factories.variable_factory import build_segment_with_type
|
from factories.variable_factory import build_segment_with_type
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_user, login_required
|
||||||
from models import App, AppMode
|
from models import Account, App, AppMode
|
||||||
from models.account import Account
|
|
||||||
from models.workflow import WorkflowDraftVariable
|
from models.workflow import WorkflowDraftVariable
|
||||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||||
from services.workflow_service import WorkflowService
|
from services.workflow_service import WorkflowService
|
||||||
@ -58,16 +57,18 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
|||||||
|
|
||||||
|
|
||||||
def _create_pagination_parser():
|
def _create_pagination_parser():
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"page",
|
.add_argument(
|
||||||
type=inputs.int_range(1, 100_000),
|
"page",
|
||||||
required=False,
|
type=inputs.int_range(1, 100_000),
|
||||||
default=1,
|
required=False,
|
||||||
location="args",
|
default=1,
|
||||||
help="the page of data requested",
|
location="args",
|
||||||
|
help="the page of data requested",
|
||||||
|
)
|
||||||
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
)
|
)
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -320,10 +321,11 @@ class VariableApi(Resource):
|
|||||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||||
# }
|
# }
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
# Parse 'value' field as-is to maintain its original data structure
|
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
draft_var_srv = WorkflowDraftVariableService(
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
session=db.session(),
|
session=db.session(),
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
|
|
||||||
@ -9,15 +8,81 @@ from controllers.console.app.wraps import get_app_model
|
|||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.workflow_run_fields import (
|
from fields.workflow_run_fields import (
|
||||||
advanced_chat_workflow_run_pagination_fields,
|
advanced_chat_workflow_run_pagination_fields,
|
||||||
|
workflow_run_count_fields,
|
||||||
workflow_run_detail_fields,
|
workflow_run_detail_fields,
|
||||||
workflow_run_node_execution_list_fields,
|
workflow_run_node_execution_list_fields,
|
||||||
workflow_run_pagination_fields,
|
workflow_run_pagination_fields,
|
||||||
)
|
)
|
||||||
|
from libs.custom_inputs import time_duration
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import current_user, login_required
|
||||||
from models import Account, App, AppMode, EndUser
|
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
|
||||||
from services.workflow_run_service import WorkflowRunService
|
from services.workflow_run_service import WorkflowRunService
|
||||||
|
|
||||||
|
# Workflow run status choices for filtering
|
||||||
|
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_workflow_run_list_args():
|
||||||
|
"""
|
||||||
|
Parse common arguments for workflow run list endpoints.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed arguments containing last_id, limit, status, and triggered_from filters
|
||||||
|
"""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||||
|
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
parser.add_argument(
|
||||||
|
"status",
|
||||||
|
type=str,
|
||||||
|
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"triggered_from",
|
||||||
|
type=str,
|
||||||
|
choices=["debugging", "app-run"],
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
help="Filter by trigger source: debugging or app-run",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_workflow_run_count_args():
|
||||||
|
"""
|
||||||
|
Parse common arguments for workflow run count endpoints.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed arguments containing status, time_range, and triggered_from filters
|
||||||
|
"""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"status",
|
||||||
|
type=str,
|
||||||
|
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"time_range",
|
||||||
|
type=time_duration,
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"triggered_from",
|
||||||
|
type=str,
|
||||||
|
choices=["debugging", "app-run"],
|
||||||
|
location="args",
|
||||||
|
required=False,
|
||||||
|
help="Filter by trigger source: debugging or app-run",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||||
class AdvancedChatAppWorkflowRunListApi(Resource):
|
class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||||
@ -25,6 +90,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||||||
@api.doc(description="Get advanced chat workflow run list")
|
@api.doc(description="Get advanced chat workflow run list")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||||
|
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||||
|
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||||
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
|
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -35,13 +102,64 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
|||||||
"""
|
"""
|
||||||
Get advanced chat app workflow run list
|
Get advanced chat app workflow run list
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
args = _parse_workflow_run_list_args()
|
||||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
# Default to DEBUGGING if not specified
|
||||||
args = parser.parse_args()
|
triggered_from = (
|
||||||
|
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||||
|
if args.get("triggered_from")
|
||||||
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
|
)
|
||||||
|
|
||||||
workflow_run_service = WorkflowRunService()
|
workflow_run_service = WorkflowRunService()
|
||||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
|
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
|
||||||
|
app_model=app_model, args=args, triggered_from=triggered_from
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
|
||||||
|
class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||||
|
@api.doc("get_advanced_chat_workflow_runs_count")
|
||||||
|
@api.doc(description="Get advanced chat workflow runs count statistics")
|
||||||
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
|
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||||
|
@api.doc(
|
||||||
|
params={
|
||||||
|
"time_range": (
|
||||||
|
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||||
|
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||||
|
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||||
|
@marshal_with(workflow_run_count_fields)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Get advanced chat workflow runs count statistics
|
||||||
|
"""
|
||||||
|
args = _parse_workflow_run_count_args()
|
||||||
|
|
||||||
|
# Default to DEBUGGING if not specified
|
||||||
|
triggered_from = (
|
||||||
|
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||||
|
if args.get("triggered_from")
|
||||||
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_run_service = WorkflowRunService()
|
||||||
|
result = workflow_run_service.get_workflow_runs_count(
|
||||||
|
app_model=app_model,
|
||||||
|
status=args.get("status"),
|
||||||
|
time_range=args.get("time_range"),
|
||||||
|
triggered_from=triggered_from,
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -52,6 +170,8 @@ class WorkflowRunListApi(Resource):
|
|||||||
@api.doc(description="Get workflow run list")
|
@api.doc(description="Get workflow run list")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||||
|
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||||
|
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||||
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
|
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -62,13 +182,64 @@ class WorkflowRunListApi(Resource):
|
|||||||
"""
|
"""
|
||||||
Get workflow run list
|
Get workflow run list
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
args = _parse_workflow_run_list_args()
|
||||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||||
args = parser.parse_args()
|
triggered_from = (
|
||||||
|
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||||
|
if args.get("triggered_from")
|
||||||
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
|
)
|
||||||
|
|
||||||
workflow_run_service = WorkflowRunService()
|
workflow_run_service = WorkflowRunService()
|
||||||
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
|
result = workflow_run_service.get_paginate_workflow_runs(
|
||||||
|
app_model=app_model, args=args, triggered_from=triggered_from
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
|
||||||
|
class WorkflowRunCountApi(Resource):
|
||||||
|
@api.doc("get_workflow_runs_count")
|
||||||
|
@api.doc(description="Get workflow runs count statistics")
|
||||||
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
|
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||||
|
@api.doc(
|
||||||
|
params={
|
||||||
|
"time_range": (
|
||||||
|
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||||
|
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||||
|
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
@marshal_with(workflow_run_count_fields)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Get workflow runs count statistics
|
||||||
|
"""
|
||||||
|
args = _parse_workflow_run_count_args()
|
||||||
|
|
||||||
|
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||||
|
triggered_from = (
|
||||||
|
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||||
|
if args.get("triggered_from")
|
||||||
|
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_run_service = WorkflowRunService()
|
||||||
|
result = workflow_run_service.get_workflow_runs_count(
|
||||||
|
app_model=app_model,
|
||||||
|
status=args.get("status"),
|
||||||
|
time_range=args.get("time_range"),
|
||||||
|
triggered_from=triggered_from,
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@ -1,24 +1,26 @@
|
|||||||
from datetime import datetime
|
from flask import abort, jsonify
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
import pytz
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from flask import jsonify
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from libs.datetime_utils import parse_time_range
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||||
class WorkflowDailyRunsStatistic(Resource):
|
class WorkflowDailyRunsStatistic(Resource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
||||||
@api.doc("get_workflow_daily_runs_statistic")
|
@api.doc("get_workflow_daily_runs_statistic")
|
||||||
@api.doc(description="Get workflow daily runs statistics")
|
@api.doc(description="Get workflow daily runs statistics")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@ -29,64 +31,41 @@ class WorkflowDailyRunsStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
assert account.timezone is not None
|
||||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
|
||||||
COUNT(id) AS runs
|
|
||||||
FROM
|
|
||||||
workflow_runs
|
|
||||||
WHERE
|
|
||||||
app_id = :app_id
|
|
||||||
AND triggered_from = :triggered_from"""
|
|
||||||
arg_dict = {
|
|
||||||
"tz": account.timezone,
|
|
||||||
"app_id": app_model.id,
|
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
|
||||||
}
|
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
|
abort(400, description=str(e))
|
||||||
|
|
||||||
if args["start"]:
|
response_data = self._workflow_run_repo.get_daily_runs_statistics(
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
tenant_id=app_model.tenant_id,
|
||||||
start_datetime = start_datetime.replace(second=0)
|
app_id=app_model.id,
|
||||||
|
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_date=start_date,
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
end_date=end_date,
|
||||||
|
timezone=account.timezone,
|
||||||
sql_query += " AND created_at >= :start"
|
)
|
||||||
arg_dict["start"] = start_datetime_utc
|
|
||||||
|
|
||||||
if args["end"]:
|
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
|
||||||
arg_dict["end"] = end_datetime_utc
|
|
||||||
|
|
||||||
sql_query += " GROUP BY date ORDER BY date"
|
|
||||||
|
|
||||||
response_data = []
|
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
|
||||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
|
||||||
for i in rs:
|
|
||||||
response_data.append({"date": str(i.date), "runs": i.runs})
|
|
||||||
|
|
||||||
return jsonify({"data": response_data})
|
return jsonify({"data": response_data})
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
||||||
class WorkflowDailyTerminalsStatistic(Resource):
|
class WorkflowDailyTerminalsStatistic(Resource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
||||||
@api.doc("get_workflow_daily_terminals_statistic")
|
@api.doc("get_workflow_daily_terminals_statistic")
|
||||||
@api.doc(description="Get workflow daily terminals statistics")
|
@api.doc(description="Get workflow daily terminals statistics")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@ -97,64 +76,41 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
assert account.timezone is not None
|
||||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
|
||||||
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
|
|
||||||
FROM
|
|
||||||
workflow_runs
|
|
||||||
WHERE
|
|
||||||
app_id = :app_id
|
|
||||||
AND triggered_from = :triggered_from"""
|
|
||||||
arg_dict = {
|
|
||||||
"tz": account.timezone,
|
|
||||||
"app_id": app_model.id,
|
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
|
||||||
}
|
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
|
abort(400, description=str(e))
|
||||||
|
|
||||||
if args["start"]:
|
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
tenant_id=app_model.tenant_id,
|
||||||
start_datetime = start_datetime.replace(second=0)
|
app_id=app_model.id,
|
||||||
|
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_date=start_date,
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
end_date=end_date,
|
||||||
|
timezone=account.timezone,
|
||||||
sql_query += " AND created_at >= :start"
|
)
|
||||||
arg_dict["start"] = start_datetime_utc
|
|
||||||
|
|
||||||
if args["end"]:
|
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
|
||||||
arg_dict["end"] = end_datetime_utc
|
|
||||||
|
|
||||||
sql_query += " GROUP BY date ORDER BY date"
|
|
||||||
|
|
||||||
response_data = []
|
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
|
||||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
|
||||||
for i in rs:
|
|
||||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
|
||||||
|
|
||||||
return jsonify({"data": response_data})
|
return jsonify({"data": response_data})
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
|
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
|
||||||
class WorkflowDailyTokenCostStatistic(Resource):
|
class WorkflowDailyTokenCostStatistic(Resource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
||||||
@api.doc("get_workflow_daily_token_cost_statistic")
|
@api.doc("get_workflow_daily_token_cost_statistic")
|
||||||
@api.doc(description="Get workflow daily token cost statistics")
|
@api.doc(description="Get workflow daily token cost statistics")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@ -165,69 +121,41 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
assert account.timezone is not None
|
||||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
|
||||||
SUM(workflow_runs.total_tokens) AS token_count
|
|
||||||
FROM
|
|
||||||
workflow_runs
|
|
||||||
WHERE
|
|
||||||
app_id = :app_id
|
|
||||||
AND triggered_from = :triggered_from"""
|
|
||||||
arg_dict = {
|
|
||||||
"tz": account.timezone,
|
|
||||||
"app_id": app_model.id,
|
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
|
||||||
}
|
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
|
abort(400, description=str(e))
|
||||||
|
|
||||||
if args["start"]:
|
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
tenant_id=app_model.tenant_id,
|
||||||
start_datetime = start_datetime.replace(second=0)
|
app_id=app_model.id,
|
||||||
|
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_date=start_date,
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
end_date=end_date,
|
||||||
|
timezone=account.timezone,
|
||||||
sql_query += " AND created_at >= :start"
|
)
|
||||||
arg_dict["start"] = start_datetime_utc
|
|
||||||
|
|
||||||
if args["end"]:
|
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query += " AND created_at < :end"
|
|
||||||
arg_dict["end"] = end_datetime_utc
|
|
||||||
|
|
||||||
sql_query += " GROUP BY date ORDER BY date"
|
|
||||||
|
|
||||||
response_data = []
|
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
|
||||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
|
||||||
for i in rs:
|
|
||||||
response_data.append(
|
|
||||||
{
|
|
||||||
"date": str(i.date),
|
|
||||||
"token_count": i.token_count,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonify({"data": response_data})
|
return jsonify({"data": response_data})
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
|
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
|
||||||
class WorkflowAverageAppInteractionStatistic(Resource):
|
class WorkflowAverageAppInteractionStatistic(Resource):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
|
||||||
@api.doc("get_workflow_average_app_interaction_statistic")
|
@api.doc("get_workflow_average_app_interaction_statistic")
|
||||||
@api.doc(description="Get workflow average app interaction statistics")
|
@api.doc(description="Get workflow average app interaction statistics")
|
||||||
@api.doc(params={"app_id": "Application ID"})
|
@api.doc(params={"app_id": "Application ID"})
|
||||||
@ -238,74 +166,29 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||||
def get(self, app_model):
|
def get(self, app_model):
|
||||||
account = current_user
|
account, _ = current_account_with_tenant()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
sql_query = """SELECT
|
assert account.timezone is not None
|
||||||
AVG(sub.interactions) AS interactions,
|
|
||||||
sub.date
|
|
||||||
FROM
|
|
||||||
(
|
|
||||||
SELECT
|
|
||||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
|
||||||
c.created_by,
|
|
||||||
COUNT(c.id) AS interactions
|
|
||||||
FROM
|
|
||||||
workflow_runs c
|
|
||||||
WHERE
|
|
||||||
c.app_id = :app_id
|
|
||||||
AND c.triggered_from = :triggered_from
|
|
||||||
{{start}}
|
|
||||||
{{end}}
|
|
||||||
GROUP BY
|
|
||||||
date, c.created_by
|
|
||||||
) sub
|
|
||||||
GROUP BY
|
|
||||||
sub.date"""
|
|
||||||
arg_dict = {
|
|
||||||
"tz": account.timezone,
|
|
||||||
"app_id": app_model.id,
|
|
||||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
|
||||||
}
|
|
||||||
|
|
||||||
timezone = pytz.timezone(account.timezone)
|
try:
|
||||||
utc_timezone = pytz.utc
|
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||||
|
except ValueError as e:
|
||||||
|
abort(400, description=str(e))
|
||||||
|
|
||||||
if args["start"]:
|
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
|
||||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
tenant_id=app_model.tenant_id,
|
||||||
start_datetime = start_datetime.replace(second=0)
|
app_id=app_model.id,
|
||||||
|
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||||
start_datetime_timezone = timezone.localize(start_datetime)
|
start_date=start_date,
|
||||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
end_date=end_date,
|
||||||
|
timezone=account.timezone,
|
||||||
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
|
)
|
||||||
arg_dict["start"] = start_datetime_utc
|
|
||||||
else:
|
|
||||||
sql_query = sql_query.replace("{{start}}", "")
|
|
||||||
|
|
||||||
if args["end"]:
|
|
||||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
|
||||||
end_datetime = end_datetime.replace(second=0)
|
|
||||||
|
|
||||||
end_datetime_timezone = timezone.localize(end_datetime)
|
|
||||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
|
||||||
|
|
||||||
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
|
|
||||||
arg_dict["end"] = end_datetime_utc
|
|
||||||
else:
|
|
||||||
sql_query = sql_query.replace("{{end}}", "")
|
|
||||||
|
|
||||||
response_data = []
|
|
||||||
|
|
||||||
with db.engine.begin() as conn:
|
|
||||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
|
||||||
for i in rs:
|
|
||||||
response_data.append(
|
|
||||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonify({"data": response_data})
|
return jsonify({"data": response_data})
|
||||||
|
|||||||
@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
|
|||||||
|
|
||||||
from controllers.console.app.error import AppNotFoundError
|
from controllers.console.app.error import AppNotFoundError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models import App, AppMode
|
from models import App, AppMode
|
||||||
from models.account import Account
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
P1 = ParamSpec("P1")
|
||||||
|
R1 = TypeVar("R1")
|
||||||
|
|
||||||
|
|
||||||
def _load_app_model(app_id: str) -> App | None:
|
def _load_app_model(app_id: str) -> App | None:
|
||||||
assert isinstance(current_user, Account)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
app_model = (
|
app_model = (
|
||||||
db.session.query(App)
|
db.session.query(App)
|
||||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
return app_model
|
return app_model
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||||
def decorator(view_func: Callable[P, R]):
|
def decorator(view_func: Callable[P1, R1]):
|
||||||
@wraps(view_func)
|
@wraps(view_func)
|
||||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
|
||||||
if not kwargs.get("app_id"):
|
if not kwargs.get("app_id"):
|
||||||
raise ValueError("missing app_id in path parameters")
|
raise ValueError("missing app_id in path parameters")
|
||||||
|
|
||||||
|
|||||||
@ -7,18 +7,14 @@ from controllers.console.error import AlreadyActivateError
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
||||||
from models.account import AccountStatus
|
from models import AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService
|
from services.account_service import AccountService, RegisterService
|
||||||
|
|
||||||
active_check_parser = reqparse.RequestParser()
|
active_check_parser = (
|
||||||
active_check_parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
|
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
||||||
)
|
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
||||||
active_check_parser.add_argument(
|
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
||||||
"email", type=email, required=False, nullable=True, location="args", help="Email address"
|
|
||||||
)
|
|
||||||
active_check_parser.add_argument(
|
|
||||||
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -60,15 +56,15 @@ class ActivateCheckApi(Resource):
|
|||||||
return {"is_valid": False}
|
return {"is_valid": False}
|
||||||
|
|
||||||
|
|
||||||
active_parser = reqparse.RequestParser()
|
active_parser = (
|
||||||
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||||
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||||
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
active_parser.add_argument(
|
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/activate")
|
@console_ns.route("/activate")
|
||||||
@ -103,7 +99,7 @@ class ActivateApi(Resource):
|
|||||||
account.interface_language = args["interface_language"]
|
account.interface_language = args["interface_language"]
|
||||||
account.timezone = args["timezone"]
|
account.timezone = args["timezone"]
|
||||||
account.interface_theme = "light"
|
account.interface_theme = "light"
|
||||||
account.status = AccountStatus.ACTIVE.value
|
account.status = AccountStatus.ACTIVE
|
||||||
account.initialized_at = naive_utc_now()
|
account.initialized_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
|||||||
@ -1,21 +1,22 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||||
|
|
||||||
from ..wraps import account_initialization_required, setup_required
|
from ..wraps import account_initialization_required, setup_required
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/api-key-auth/data-source")
|
||||||
class ApiKeyAuthDataSource(Resource):
|
class ApiKeyAuthDataSource(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||||
if data_source_api_key_bindings:
|
if data_source_api_key_bindings:
|
||||||
return {
|
return {
|
||||||
"sources": [
|
"sources": [
|
||||||
@ -33,41 +34,44 @@ class ApiKeyAuthDataSource(Resource):
|
|||||||
return {"sources": []}
|
return {"sources": []}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/api-key-auth/data-source/binding")
|
||||||
class ApiKeyAuthDataSourceBinding(Resource):
|
class ApiKeyAuthDataSourceBinding(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||||
try:
|
try:
|
||||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ApiKeyAuthFailedError(str(e))
|
raise ApiKeyAuthFailedError(str(e))
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/api-key-auth/data-source/<uuid:binding_id>")
|
||||||
class ApiKeyAuthDataSourceBindingDelete(Resource):
|
class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, binding_id):
|
def delete(self, binding_id):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
|
|
||||||
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
|
|
||||||
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")
|
|
||||||
|
|||||||
@ -2,13 +2,12 @@ import logging
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from flask import current_app, redirect, request
|
from flask import current_app, redirect, request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from libs.oauth_data_source import NotionOAuth
|
from libs.oauth_data_source import NotionOAuth
|
||||||
|
|
||||||
from ..wraps import account_initialization_required, setup_required
|
from ..wraps import account_initialization_required, setup_required
|
||||||
@ -45,6 +44,7 @@ class OAuthDataSource(Resource):
|
|||||||
@api.response(403, "Admin privileges required")
|
@api.response(403, "Admin privileges required")
|
||||||
def get(self, provider: str):
|
def get(self, provider: str):
|
||||||
# The role of the current user in the table must be admin or owner
|
# The role of the current user in the table must be admin or owner
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.auth.error import (
|
from controllers.console.auth.error import (
|
||||||
EmailAlreadyInUseError,
|
EmailAlreadyInUseError,
|
||||||
EmailCodeError,
|
EmailCodeError,
|
||||||
@ -19,20 +19,23 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import email, extract_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/email-register/send-email")
|
||||||
class EmailRegisterSendEmailApi(Resource):
|
class EmailRegisterSendEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
@ -52,15 +55,18 @@ class EmailRegisterSendEmailApi(Resource):
|
|||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/email-register/validity")
|
||||||
class EmailRegisterCheckApi(Resource):
|
class EmailRegisterCheckApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("code", type=str, required=True, location="json")
|
.add_argument("email", type=str, required=True, location="json")
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
.add_argument("code", type=str, required=True, location="json")
|
||||||
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args["email"]
|
||||||
@ -92,15 +98,18 @@ class EmailRegisterCheckApi(Resource):
|
|||||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/email-register")
|
||||||
class EmailRegisterResetApi(Resource):
|
class EmailRegisterResetApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
@email_register_enabled
|
@email_register_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate passwords match
|
# Validate passwords match
|
||||||
@ -148,8 +157,3 @@ class EmailRegisterResetApi(Resource):
|
|||||||
raise AccountInFreezeError()
|
raise AccountInFreezeError()
|
||||||
|
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email")
|
|
||||||
api.add_resource(EmailRegisterCheckApi, "/email-register/validity")
|
|
||||||
api.add_resource(EmailRegisterResetApi, "/email-register")
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from events.tenant_event import tenant_was_created
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import email, extract_remote_ip
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password, valid_password
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService, TenantService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
@ -54,9 +54,11 @@ class ForgotPasswordSendEmailApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
@ -111,10 +113,12 @@ class ForgotPasswordCheckApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("code", type=str, required=True, location="json")
|
.add_argument("email", type=str, required=True, location="json")
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
.add_argument("code", type=str, required=True, location="json")
|
||||||
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args["email"]
|
||||||
@ -169,10 +173,12 @@ class ForgotPasswordResetApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate passwords match
|
# Validate passwords match
|
||||||
@ -221,8 +227,3 @@ class ForgotPasswordResetApi(Resource):
|
|||||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||||
account.current_tenant = tenant
|
account.current_tenant = tenant
|
||||||
tenant_was_created.send(tenant)
|
tenant_was_created.send(tenant)
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
|
||||||
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
|
||||||
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
|
||||||
|
|||||||
@ -1,13 +1,11 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
import flask_login
|
import flask_login
|
||||||
from flask import request
|
from flask import make_response, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import languages
|
from constants.languages import get_valid_language
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.auth.error import (
|
from controllers.console.auth.error import (
|
||||||
AuthenticationFailedError,
|
AuthenticationFailedError,
|
||||||
EmailCodeError,
|
EmailCodeError,
|
||||||
@ -26,7 +24,16 @@ from controllers.console.error import (
|
|||||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||||
from events.tenant_event import tenant_was_created
|
from events.tenant_event import tenant_was_created
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import email, extract_remote_ip
|
||||||
from models.account import Account
|
from libs.login import current_account_with_tenant
|
||||||
|
from libs.token import (
|
||||||
|
clear_access_token_from_cookie,
|
||||||
|
clear_csrf_token_from_cookie,
|
||||||
|
clear_refresh_token_from_cookie,
|
||||||
|
extract_refresh_token,
|
||||||
|
set_access_token_to_cookie,
|
||||||
|
set_csrf_token_to_cookie,
|
||||||
|
set_refresh_token_to_cookie,
|
||||||
|
)
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.errors.account import AccountRegisterError
|
from services.errors.account import AccountRegisterError
|
||||||
@ -34,6 +41,7 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
|
|||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/login")
|
||||||
class LoginApi(Resource):
|
class LoginApi(Resource):
|
||||||
"""Resource for user login."""
|
"""Resource for user login."""
|
||||||
|
|
||||||
@ -41,11 +49,13 @@ class LoginApi(Resource):
|
|||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
"""Authenticate user and login."""
|
"""Authenticate user and login."""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("password", type=str, required=True, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
.add_argument("password", type=str, required=True, location="json")
|
||||||
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||||
|
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||||
@ -88,27 +98,48 @@ class LoginApi(Resource):
|
|||||||
|
|
||||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||||
AccountService.reset_login_error_rate_limit(args["email"])
|
AccountService.reset_login_error_rate_limit(args["email"])
|
||||||
return {"result": "success", "data": token_pair.model_dump()}
|
|
||||||
|
# Create response with cookies instead of returning tokens in body
|
||||||
|
response = make_response({"result": "success"})
|
||||||
|
|
||||||
|
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||||
|
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||||
|
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/logout")
|
||||||
class LogoutApi(Resource):
|
class LogoutApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def get(self):
|
def post(self):
|
||||||
account = cast(Account, flask_login.current_user)
|
current_user, _ = current_account_with_tenant()
|
||||||
|
account = current_user
|
||||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||||
return {"result": "success"}
|
response = make_response({"result": "success"})
|
||||||
AccountService.logout(account=account)
|
else:
|
||||||
flask_login.logout_user()
|
AccountService.logout(account=account)
|
||||||
return {"result": "success"}
|
flask_login.logout_user()
|
||||||
|
response = make_response({"result": "success"})
|
||||||
|
|
||||||
|
# Clear cookies on logout
|
||||||
|
clear_access_token_from_cookie(response)
|
||||||
|
clear_refresh_token_from_cookie(response)
|
||||||
|
clear_csrf_token_from_cookie(response)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/reset-password")
|
||||||
class ResetPasswordSendEmailApi(Resource):
|
class ResetPasswordSendEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@email_password_login_enabled
|
@email_password_login_enabled
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||||
@ -130,12 +161,15 @@ class ResetPasswordSendEmailApi(Resource):
|
|||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/email-code-login")
|
||||||
class EmailCodeLoginSendEmailApi(Resource):
|
class EmailCodeLoginSendEmailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("language", type=str, required=False, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
@ -162,16 +196,21 @@ class EmailCodeLoginSendEmailApi(Resource):
|
|||||||
return {"result": "success", "data": token}
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/email-code-login/validity")
|
||||||
class EmailCodeLoginApi(Resource):
|
class EmailCodeLoginApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("code", type=str, required=True, location="json")
|
.add_argument("email", type=str, required=True, location="json")
|
||||||
parser.add_argument("token", type=str, required=True, location="json")
|
.add_argument("code", type=str, required=True, location="json")
|
||||||
|
.add_argument("token", type=str, required=True, location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
user_email = args["email"]
|
user_email = args["email"]
|
||||||
|
language = args["language"]
|
||||||
|
|
||||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
token_data = AccountService.get_email_code_login_data(args["token"])
|
||||||
if token_data is None:
|
if token_data is None:
|
||||||
@ -205,7 +244,9 @@ class EmailCodeLoginApi(Resource):
|
|||||||
if account is None:
|
if account is None:
|
||||||
try:
|
try:
|
||||||
account = AccountService.create_account_and_tenant(
|
account = AccountService.create_account_and_tenant(
|
||||||
email=user_email, name=user_email, interface_language=languages[0]
|
email=user_email,
|
||||||
|
name=user_email,
|
||||||
|
interface_language=get_valid_language(language),
|
||||||
)
|
)
|
||||||
except WorkSpaceNotAllowedCreateError:
|
except WorkSpaceNotAllowedCreateError:
|
||||||
raise NotAllowedCreateWorkspace()
|
raise NotAllowedCreateWorkspace()
|
||||||
@ -215,25 +256,36 @@ class EmailCodeLoginApi(Resource):
|
|||||||
raise WorkspacesLimitExceeded()
|
raise WorkspacesLimitExceeded()
|
||||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||||
AccountService.reset_login_error_rate_limit(args["email"])
|
AccountService.reset_login_error_rate_limit(args["email"])
|
||||||
return {"result": "success", "data": token_pair.model_dump()}
|
|
||||||
|
# Create response with cookies instead of returning tokens in body
|
||||||
|
response = make_response({"result": "success"})
|
||||||
|
|
||||||
|
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||||
|
# Set HTTP-only secure cookies for tokens
|
||||||
|
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||||
|
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/refresh-token")
|
||||||
class RefreshTokenApi(Resource):
|
class RefreshTokenApi(Resource):
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
# Get refresh token from cookie instead of request body
|
||||||
parser.add_argument("refresh_token", type=str, required=True, location="json")
|
refresh_token = extract_refresh_token(request)
|
||||||
args = parser.parse_args()
|
|
||||||
|
if not refresh_token:
|
||||||
|
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||||
|
|
||||||
try:
|
try:
|
||||||
new_token_pair = AccountService.refresh_token(args["refresh_token"])
|
new_token_pair = AccountService.refresh_token(refresh_token)
|
||||||
return {"result": "success", "data": new_token_pair.model_dump()}
|
|
||||||
|
# Create response with new cookies
|
||||||
|
response = make_response({"result": "success"})
|
||||||
|
|
||||||
|
# Update cookies with new tokens
|
||||||
|
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
|
||||||
|
set_access_token_to_cookie(request, response, new_token_pair.access_token)
|
||||||
|
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
|
||||||
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"result": "fail", "data": str(e)}, 401
|
return {"result": "fail", "message": str(e)}, 401
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(LoginApi, "/login")
|
|
||||||
api.add_resource(LogoutApi, "/logout")
|
|
||||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
|
||||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
|
||||||
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
|
|
||||||
api.add_resource(RefreshTokenApi, "/refresh-token")
|
|
||||||
|
|||||||
@ -14,8 +14,12 @@ from extensions.ext_database import db
|
|||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||||
from models import Account
|
from libs.token import (
|
||||||
from models.account import AccountStatus
|
set_access_token_to_cookie,
|
||||||
|
set_csrf_token_to_cookie,
|
||||||
|
set_refresh_token_to_cookie,
|
||||||
|
)
|
||||||
|
from models import Account, AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||||
@ -130,11 +134,11 @@ class OAuthCallback(Resource):
|
|||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
|
||||||
|
|
||||||
# Check account status
|
# Check account status
|
||||||
if account.status == AccountStatus.BANNED.value:
|
if account.status == AccountStatus.BANNED:
|
||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
|
||||||
|
|
||||||
if account.status == AccountStatus.PENDING.value:
|
if account.status == AccountStatus.PENDING:
|
||||||
account.status = AccountStatus.ACTIVE.value
|
account.status = AccountStatus.ACTIVE
|
||||||
account.initialized_at = naive_utc_now()
|
account.initialized_at = naive_utc_now()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@ -153,9 +157,12 @@ class OAuthCallback(Resource):
|
|||||||
ip_address=extract_remote_ip(request),
|
ip_address=extract_remote_ip(request),
|
||||||
)
|
)
|
||||||
|
|
||||||
return redirect(
|
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||||
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
|
|
||||||
)
|
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||||
|
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||||
|
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:
|
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:
|
||||||
|
|||||||
@ -1,20 +1,19 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
from typing import Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
import flask_login
|
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import BadRequest, NotFound
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from models.model import OAuthProviderApp
|
from models.model import OAuthProviderApp
|
||||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||||
|
|
||||||
from .. import api
|
from .. import console_ns
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
@ -24,8 +23,7 @@ T = TypeVar("T")
|
|||||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
||||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
|
||||||
parsed_args = parser.parse_args()
|
parsed_args = parser.parse_args()
|
||||||
client_id = parsed_args.get("client_id")
|
client_id = parsed_args.get("client_id")
|
||||||
if not client_id:
|
if not client_id:
|
||||||
@ -86,12 +84,12 @@ def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProvid
|
|||||||
return decorated
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/oauth/provider")
|
||||||
class OAuthServerAppApi(Resource):
|
class OAuthServerAppApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
||||||
parser.add_argument("redirect_uri", type=str, required=True, location="json")
|
|
||||||
parsed_args = parser.parse_args()
|
parsed_args = parser.parse_args()
|
||||||
redirect_uri = parsed_args.get("redirect_uri")
|
redirect_uri = parsed_args.get("redirect_uri")
|
||||||
|
|
||||||
@ -108,13 +106,15 @@ class OAuthServerAppApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/oauth/provider/authorize")
|
||||||
class OAuthServerUserAuthorizeApi(Resource):
|
class OAuthServerUserAuthorizeApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||||
account = cast(Account, flask_login.current_user)
|
current_user, _ = current_account_with_tenant()
|
||||||
|
account = current_user
|
||||||
user_account_id = account.id
|
user_account_id = account.id
|
||||||
|
|
||||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||||
@ -125,16 +125,19 @@ class OAuthServerUserAuthorizeApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/oauth/provider/token")
|
||||||
class OAuthServerUserTokenApi(Resource):
|
class OAuthServerUserTokenApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("grant_type", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("code", type=str, required=False, location="json")
|
.add_argument("grant_type", type=str, required=True, location="json")
|
||||||
parser.add_argument("client_secret", type=str, required=False, location="json")
|
.add_argument("code", type=str, required=False, location="json")
|
||||||
parser.add_argument("redirect_uri", type=str, required=False, location="json")
|
.add_argument("client_secret", type=str, required=False, location="json")
|
||||||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||||
|
.add_argument("refresh_token", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
parsed_args = parser.parse_args()
|
parsed_args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -180,6 +183,7 @@ class OAuthServerUserTokenApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/oauth/provider/account")
|
||||||
class OAuthServerUserAccountApi(Resource):
|
class OAuthServerUserAccountApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@oauth_server_client_id_required
|
@oauth_server_client_id_required
|
||||||
@ -194,9 +198,3 @@ class OAuthServerUserAccountApi(Resource):
|
|||||||
"timezone": account.timezone,
|
"timezone": account.timezone,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(OAuthServerAppApi, "/oauth/provider")
|
|
||||||
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
|
|
||||||
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
|
|
||||||
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")
|
|
||||||
|
|||||||
@ -1,42 +1,43 @@
|
|||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||||
from libs.login import current_user, login_required
|
from enums.cloud_plan import CloudPlan
|
||||||
from models.model import Account
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/billing/subscription")
|
||||||
class Subscription(Resource):
|
class Subscription(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
parser = (
|
||||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
reqparse.RequestParser()
|
||||||
args = parser.parse_args()
|
.add_argument(
|
||||||
assert isinstance(current_user, Account)
|
"plan",
|
||||||
|
type=str,
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
required=True,
|
||||||
assert current_user.current_tenant_id is not None
|
location="args",
|
||||||
return BillingService.get_subscription(
|
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
||||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
)
|
||||||
|
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||||
)
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
|
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/billing/invoices")
|
||||||
class Invoices(Resource):
|
class Invoices(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
assert isinstance(current_user, Account)
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
BillingService.is_tenant_owner_or_admin(current_user)
|
BillingService.is_tenant_owner_or_admin(current_user)
|
||||||
assert current_user.current_tenant_id is not None
|
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(Subscription, "/billing/subscription")
|
|
||||||
api.add_resource(Invoices, "/billing/invoices")
|
|
||||||
|
|||||||
@ -1,35 +1,31 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
|
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
from .. import api
|
from .. import console_ns
|
||||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/compliance/download")
|
||||||
class ComplianceApi(Resource):
|
class ComplianceApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@only_edition_cloud
|
@only_edition_cloud
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
ip_address = extract_remote_ip(request)
|
||||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||||
|
|
||||||
return BillingService.get_compliance_download_link(
|
return BillingService.get_compliance_download_link(
|
||||||
doc_name=args.doc_name,
|
doc_name=args.doc_name,
|
||||||
account_id=current_user.id,
|
account_id=current_user.id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
ip=ip_address,
|
ip=ip_address,
|
||||||
device_info=device_info,
|
device_info=device_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ComplianceApi, "/compliance/download")
|
|
||||||
|
|||||||
@ -3,40 +3,45 @@ from collections.abc import Generator
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||||
from core.indexing_runner import IndexingRunner
|
from core.indexing_runner import IndexingRunner
|
||||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import DataSourceOauthBinding, Document
|
from models import DataSourceOauthBinding, Document
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from services.datasource_provider_service import DatasourceProviderService
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/data-source/integrates",
|
||||||
|
"/data-source/integrates/<uuid:binding_id>/<string:action>",
|
||||||
|
)
|
||||||
class DataSourceApi(Resource):
|
class DataSourceApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_list_fields)
|
@marshal_with(integrate_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# get workspace data source integrates
|
# get workspace data source integrates
|
||||||
data_source_integrates = db.session.scalars(
|
data_source_integrates = db.session.scalars(
|
||||||
select(DataSourceOauthBinding).where(
|
select(DataSourceOauthBinding).where(
|
||||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
DataSourceOauthBinding.tenant_id == current_tenant_id,
|
||||||
DataSourceOauthBinding.disabled == False,
|
DataSourceOauthBinding.disabled == False,
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
@ -109,19 +114,22 @@ class DataSourceApi(Resource):
|
|||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/notion/pre-import/pages")
|
||||||
class DataSourceNotionListApi(Resource):
|
class DataSourceNotionListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(integrate_notion_info_list_fields)
|
@marshal_with(integrate_notion_info_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||||
if not credential_id:
|
if not credential_id:
|
||||||
raise ValueError("Credential id is required.")
|
raise ValueError("Credential id is required.")
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
credential = datasource_provider_service.get_datasource_credentials(
|
credential = datasource_provider_service.get_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
provider="notion_datasource",
|
provider="notion_datasource",
|
||||||
plugin_id="langgenius/notion_datasource",
|
plugin_id="langgenius/notion_datasource",
|
||||||
@ -141,7 +149,7 @@ class DataSourceNotionListApi(Resource):
|
|||||||
documents = session.scalars(
|
documents = session.scalars(
|
||||||
select(Document).filter_by(
|
select(Document).filter_by(
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
data_source_type="notion_import",
|
data_source_type="notion_import",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
)
|
)
|
||||||
@ -156,7 +164,7 @@ class DataSourceNotionListApi(Resource):
|
|||||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||||
datasource_name="notion_datasource",
|
datasource_name="notion_datasource",
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||||
)
|
)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
@ -196,17 +204,23 @@ class DataSourceNotionListApi(Resource):
|
|||||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||||
|
"/datasets/notion-indexing-estimate",
|
||||||
|
)
|
||||||
class DataSourceNotionApi(Resource):
|
class DataSourceNotionApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, workspace_id, page_id, page_type):
|
def get(self, workspace_id, page_id, page_type):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||||
if not credential_id:
|
if not credential_id:
|
||||||
raise ValueError("Credential id is required.")
|
raise ValueError("Credential id is required.")
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
credential = datasource_provider_service.get_datasource_credentials(
|
credential = datasource_provider_service.get_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
credential_id=credential_id,
|
credential_id=credential_id,
|
||||||
provider="notion_datasource",
|
provider="notion_datasource",
|
||||||
plugin_id="langgenius/notion_datasource",
|
plugin_id="langgenius/notion_datasource",
|
||||||
@ -220,7 +234,7 @@ class DataSourceNotionApi(Resource):
|
|||||||
notion_obj_id=page_id,
|
notion_obj_id=page_id,
|
||||||
notion_page_type=page_type,
|
notion_page_type=page_type,
|
||||||
notion_access_token=credential.get("integration_secret"),
|
notion_access_token=credential.get("integration_secret"),
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
text_docs = extractor.extract()
|
text_docs = extractor.extract()
|
||||||
@ -230,12 +244,14 @@ class DataSourceNotionApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
|
||||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
parser = (
|
||||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||||
|
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
|
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# validate args
|
# validate args
|
||||||
@ -247,20 +263,22 @@ class DataSourceNotionApi(Resource):
|
|||||||
credential_id = notion_info.get("credential_id")
|
credential_id = notion_info.get("credential_id")
|
||||||
for page in notion_info["pages"]:
|
for page in notion_info["pages"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.NOTION.value,
|
datasource_type=DatasourceType.NOTION,
|
||||||
notion_info={
|
notion_info=NotionInfo.model_validate(
|
||||||
"credential_id": credential_id,
|
{
|
||||||
"notion_workspace_id": workspace_id,
|
"credential_id": credential_id,
|
||||||
"notion_obj_id": page["page_id"],
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_page_type": page["type"],
|
"notion_obj_id": page["page_id"],
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"notion_page_type": page["type"],
|
||||||
},
|
"tenant_id": current_tenant_id,
|
||||||
|
}
|
||||||
|
),
|
||||||
document_model=args["doc_form"],
|
document_model=args["doc_form"],
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
response = indexing_runner.indexing_estimate(
|
response = indexing_runner.indexing_estimate(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
extract_settings,
|
extract_settings,
|
||||||
args["process_rule"],
|
args["process_rule"],
|
||||||
args["doc_form"],
|
args["doc_form"],
|
||||||
@ -269,6 +287,7 @@ class DataSourceNotionApi(Resource):
|
|||||||
return response.model_dump(), 200
|
return response.model_dump(), 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
|
||||||
class DataSourceNotionDatasetSyncApi(Resource):
|
class DataSourceNotionDatasetSyncApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -285,6 +304,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
|||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync")
|
||||||
class DataSourceNotionDocumentSyncApi(Resource):
|
class DataSourceNotionDocumentSyncApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -301,16 +321,3 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
|||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
document_indexing_sync_task.delay(dataset_id_str, document_id_str)
|
document_indexing_sync_task.delay(dataset_id_str, document_id_str)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
|
|
||||||
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
|
|
||||||
api.add_resource(
|
|
||||||
DataSourceNotionApi,
|
|
||||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
|
||||||
"/datasets/notion-indexing-estimate",
|
|
||||||
)
|
|
||||||
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
|
|
||||||
api.add_resource(
|
|
||||||
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import flask_restx
|
from typing import Any, cast
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
@ -23,29 +23,97 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
from core.rag.datasource.vdb.vector_type import VectorType
|
from core.rag.datasource.vdb.vector_type import VectorType
|
||||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import related_app_list
|
from fields.app_fields import related_app_list
|
||||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||||
from fields.document_fields import document_status_fields
|
from fields.document_fields import document_status_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
|
from libs.validators import validate_description_length
|
||||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||||
from models.dataset import DatasetPermissionEnum
|
from models.dataset import DatasetPermissionEnum
|
||||||
from models.provider_ids import ModelProviderID
|
from models.provider_ids import ModelProviderID
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name):
|
def _validate_name(name: str) -> str:
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
if not name or len(name) < 1 or len(name) > 40:
|
||||||
raise ValueError("Name must be between 1 to 40 characters.")
|
raise ValueError("Name must be between 1 to 40 characters.")
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def _validate_description_length(description):
|
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||||
if description and len(description) > 400:
|
"""
|
||||||
raise ValueError("Description cannot exceed 400 characters.")
|
Get supported retrieval methods based on vector database type.
|
||||||
return description
|
|
||||||
|
Args:
|
||||||
|
vector_type: Vector database type, can be None
|
||||||
|
is_mock: Whether this is a Mock API, affects MILVUS handling
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing supported retrieval methods
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If vector_type is None or unsupported
|
||||||
|
"""
|
||||||
|
if vector_type is None:
|
||||||
|
raise ValueError("Vector store type is not configured.")
|
||||||
|
|
||||||
|
# Define vector database types that only support semantic search
|
||||||
|
semantic_only_types = {
|
||||||
|
VectorType.RELYT,
|
||||||
|
VectorType.TIDB_VECTOR,
|
||||||
|
VectorType.CHROMA,
|
||||||
|
VectorType.PGVECTO_RS,
|
||||||
|
VectorType.VIKINGDB,
|
||||||
|
VectorType.UPSTASH,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Define vector database types that support all retrieval methods
|
||||||
|
full_search_types = {
|
||||||
|
VectorType.QDRANT,
|
||||||
|
VectorType.WEAVIATE,
|
||||||
|
VectorType.OPENSEARCH,
|
||||||
|
VectorType.ANALYTICDB,
|
||||||
|
VectorType.MYSCALE,
|
||||||
|
VectorType.ORACLE,
|
||||||
|
VectorType.ELASTICSEARCH,
|
||||||
|
VectorType.ELASTICSEARCH_JA,
|
||||||
|
VectorType.PGVECTOR,
|
||||||
|
VectorType.VASTBASE,
|
||||||
|
VectorType.TIDB_ON_QDRANT,
|
||||||
|
VectorType.LINDORM,
|
||||||
|
VectorType.COUCHBASE,
|
||||||
|
VectorType.OPENGAUSS,
|
||||||
|
VectorType.OCEANBASE,
|
||||||
|
VectorType.TABLESTORE,
|
||||||
|
VectorType.HUAWEI_CLOUD,
|
||||||
|
VectorType.TENCENT,
|
||||||
|
VectorType.MATRIXONE,
|
||||||
|
VectorType.CLICKZETTA,
|
||||||
|
VectorType.BAIDU,
|
||||||
|
VectorType.ALIBABACLOUD_MYSQL,
|
||||||
|
}
|
||||||
|
|
||||||
|
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||||
|
full_methods = {
|
||||||
|
"retrieval_method": [
|
||||||
|
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
|
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||||
|
RetrievalMethod.HYBRID_SEARCH.value,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
if vector_type == VectorType.MILVUS:
|
||||||
|
return semantic_methods if is_mock else full_methods
|
||||||
|
|
||||||
|
if vector_type in semantic_only_types:
|
||||||
|
return semantic_methods
|
||||||
|
elif vector_type in full_search_types:
|
||||||
|
return full_methods
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/datasets")
|
@console_ns.route("/datasets")
|
||||||
@ -68,6 +136,7 @@ class DatasetListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
ids = request.args.getlist("ids")
|
ids = request.args.getlist("ids")
|
||||||
@ -76,15 +145,15 @@ class DatasetListApi(Resource):
|
|||||||
tag_ids = request.args.getlist("tag_ids")
|
tag_ids = request.args.getlist("tag_ids")
|
||||||
include_all = request.args.get("include_all", default="false").lower() == "true"
|
include_all = request.args.get("include_all", default="false").lower() == "true"
|
||||||
if ids:
|
if ids:
|
||||||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
|
||||||
else:
|
else:
|
||||||
datasets, total = DatasetService.get_datasets(
|
datasets, total = DatasetService.get_datasets(
|
||||||
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
|
page, limit, current_tenant_id, current_user, search, tag_ids, include_all
|
||||||
)
|
)
|
||||||
|
|
||||||
# check embedding setting
|
# check embedding setting
|
||||||
provider_manager = ProviderManager()
|
provider_manager = ProviderManager()
|
||||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
|
||||||
|
|
||||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||||
|
|
||||||
@ -92,7 +161,7 @@ class DatasetListApi(Resource):
|
|||||||
for embedding_model in embedding_models:
|
for embedding_model in embedding_models:
|
||||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||||
|
|
||||||
data = marshal(datasets, dataset_detail_fields)
|
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
|
||||||
for item in data:
|
for item in data:
|
||||||
# convert embedding_model_provider to plugin standard format
|
# convert embedding_model_provider to plugin standard format
|
||||||
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
|
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
|
||||||
@ -137,50 +206,53 @@ class DatasetListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
required=True,
|
nullable=False,
|
||||||
help="type is required. Name must be between 1 to 40 characters.",
|
required=True,
|
||||||
type=_validate_name,
|
help="type is required. Name must be between 1 to 40 characters.",
|
||||||
)
|
type=_validate_name,
|
||||||
parser.add_argument(
|
)
|
||||||
"description",
|
.add_argument(
|
||||||
type=_validate_description_length,
|
"description",
|
||||||
nullable=True,
|
type=validate_description_length,
|
||||||
required=False,
|
nullable=True,
|
||||||
default="",
|
required=False,
|
||||||
)
|
default="",
|
||||||
parser.add_argument(
|
)
|
||||||
"indexing_technique",
|
.add_argument(
|
||||||
type=str,
|
"indexing_technique",
|
||||||
location="json",
|
type=str,
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
location="json",
|
||||||
nullable=True,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
help="Invalid indexing technique.",
|
nullable=True,
|
||||||
)
|
help="Invalid indexing technique.",
|
||||||
parser.add_argument(
|
)
|
||||||
"external_knowledge_api_id",
|
.add_argument(
|
||||||
type=str,
|
"external_knowledge_api_id",
|
||||||
nullable=True,
|
type=str,
|
||||||
required=False,
|
nullable=True,
|
||||||
)
|
required=False,
|
||||||
parser.add_argument(
|
)
|
||||||
"provider",
|
.add_argument(
|
||||||
type=str,
|
"provider",
|
||||||
nullable=True,
|
type=str,
|
||||||
choices=Dataset.PROVIDER_LIST,
|
nullable=True,
|
||||||
required=False,
|
choices=Dataset.PROVIDER_LIST,
|
||||||
default="vendor",
|
required=False,
|
||||||
)
|
default="vendor",
|
||||||
parser.add_argument(
|
)
|
||||||
"external_knowledge_id",
|
.add_argument(
|
||||||
type=str,
|
"external_knowledge_id",
|
||||||
nullable=True,
|
type=str,
|
||||||
required=False,
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
@ -188,7 +260,7 @@ class DatasetListApi(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
dataset = DatasetService.create_empty_dataset(
|
dataset = DatasetService.create_empty_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
description=args["description"],
|
description=args["description"],
|
||||||
indexing_technique=args["indexing_technique"],
|
indexing_technique=args["indexing_technique"],
|
||||||
@ -216,6 +288,7 @@ class DatasetApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -224,7 +297,7 @@ class DatasetApi(Resource):
|
|||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
data = marshal(dataset, dataset_detail_fields)
|
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||||
if dataset.indexing_technique == "high_quality":
|
if dataset.indexing_technique == "high_quality":
|
||||||
if dataset.embedding_model_provider:
|
if dataset.embedding_model_provider:
|
||||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||||
@ -235,7 +308,7 @@ class DatasetApi(Resource):
|
|||||||
|
|
||||||
# check embedding setting
|
# check embedding setting
|
||||||
provider_manager = ProviderManager()
|
provider_manager = ProviderManager()
|
||||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
|
||||||
|
|
||||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||||
|
|
||||||
@ -281,73 +354,76 @@ class DatasetApi(Resource):
|
|||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
help="type is required. Name must be between 1 to 40 characters.",
|
nullable=False,
|
||||||
type=_validate_name,
|
help="type is required. Name must be between 1 to 40 characters.",
|
||||||
)
|
type=_validate_name,
|
||||||
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
|
)
|
||||||
parser.add_argument(
|
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||||
"indexing_technique",
|
.add_argument(
|
||||||
type=str,
|
"indexing_technique",
|
||||||
location="json",
|
type=str,
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
location="json",
|
||||||
nullable=True,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
help="Invalid indexing technique.",
|
nullable=True,
|
||||||
)
|
help="Invalid indexing technique.",
|
||||||
parser.add_argument(
|
)
|
||||||
"permission",
|
.add_argument(
|
||||||
type=str,
|
"permission",
|
||||||
location="json",
|
type=str,
|
||||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
location="json",
|
||||||
help="Invalid permission.",
|
choices=(
|
||||||
)
|
DatasetPermissionEnum.ONLY_ME,
|
||||||
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
DatasetPermissionEnum.ALL_TEAM,
|
||||||
parser.add_argument(
|
DatasetPermissionEnum.PARTIAL_TEAM,
|
||||||
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
),
|
||||||
)
|
help="Invalid permission.",
|
||||||
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
)
|
||||||
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
||||||
|
.add_argument(
|
||||||
parser.add_argument(
|
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
||||||
"external_retrieval_model",
|
)
|
||||||
type=dict,
|
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||||
required=False,
|
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||||
nullable=True,
|
.add_argument(
|
||||||
location="json",
|
"external_retrieval_model",
|
||||||
help="Invalid external retrieval model.",
|
type=dict,
|
||||||
)
|
required=False,
|
||||||
|
nullable=True,
|
||||||
parser.add_argument(
|
location="json",
|
||||||
"external_knowledge_id",
|
help="Invalid external retrieval model.",
|
||||||
type=str,
|
)
|
||||||
required=False,
|
.add_argument(
|
||||||
nullable=True,
|
"external_knowledge_id",
|
||||||
location="json",
|
type=str,
|
||||||
help="Invalid external knowledge id.",
|
required=False,
|
||||||
)
|
nullable=True,
|
||||||
|
location="json",
|
||||||
parser.add_argument(
|
help="Invalid external knowledge id.",
|
||||||
"external_knowledge_api_id",
|
)
|
||||||
type=str,
|
.add_argument(
|
||||||
required=False,
|
"external_knowledge_api_id",
|
||||||
nullable=True,
|
type=str,
|
||||||
location="json",
|
required=False,
|
||||||
help="Invalid external knowledge api id.",
|
nullable=True,
|
||||||
)
|
location="json",
|
||||||
|
help="Invalid external knowledge api id.",
|
||||||
parser.add_argument(
|
)
|
||||||
"icon_info",
|
.add_argument(
|
||||||
type=dict,
|
"icon_info",
|
||||||
required=False,
|
type=dict,
|
||||||
nullable=True,
|
required=False,
|
||||||
location="json",
|
nullable=True,
|
||||||
help="Invalid icon info.",
|
location="json",
|
||||||
|
help="Invalid icon info.",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check embedding model setting
|
# check embedding model setting
|
||||||
if (
|
if (
|
||||||
@ -369,8 +445,8 @@ class DatasetApi(Resource):
|
|||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
result_data = marshal(dataset, dataset_detail_fields)
|
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||||
tenant_id = current_user.current_tenant_id
|
tenant_id = current_tenant_id
|
||||||
|
|
||||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||||
DatasetPermissionService.update_partial_member_list(
|
DatasetPermissionService.update_partial_member_list(
|
||||||
@ -394,9 +470,9 @@ class DatasetApi(Resource):
|
|||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id):
|
def delete(self, dataset_id):
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||||
if not (current_user.is_editor or current_user.is_dataset_operator):
|
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -435,6 +511,7 @@ class DatasetQueryApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -469,32 +546,31 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
||||||
parser.add_argument(
|
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||||
"indexing_technique",
|
.add_argument(
|
||||||
type=str,
|
"indexing_technique",
|
||||||
required=True,
|
type=str,
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
required=True,
|
||||||
nullable=True,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
location="json",
|
nullable=True,
|
||||||
)
|
location="json",
|
||||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
)
|
||||||
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
parser.add_argument(
|
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
||||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
# validate args
|
# validate args
|
||||||
DocumentService.estimate_args_validate(args)
|
DocumentService.estimate_args_validate(args)
|
||||||
extract_settings = []
|
extract_settings = []
|
||||||
if args["info_list"]["data_source_type"] == "upload_file":
|
if args["info_list"]["data_source_type"] == "upload_file":
|
||||||
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
||||||
file_details = db.session.scalars(
|
file_details = db.session.scalars(
|
||||||
select(UploadFile).where(
|
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
|
||||||
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
|
|
||||||
)
|
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
if file_details is None:
|
if file_details is None:
|
||||||
@ -503,7 +579,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
if file_details:
|
if file_details:
|
||||||
for file_detail in file_details:
|
for file_detail in file_details:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.FILE.value,
|
datasource_type=DatasourceType.FILE,
|
||||||
upload_file=file_detail,
|
upload_file=file_detail,
|
||||||
document_model=args["doc_form"],
|
document_model=args["doc_form"],
|
||||||
)
|
)
|
||||||
@ -515,14 +591,16 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
credential_id = notion_info.get("credential_id")
|
credential_id = notion_info.get("credential_id")
|
||||||
for page in notion_info["pages"]:
|
for page in notion_info["pages"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.NOTION.value,
|
datasource_type=DatasourceType.NOTION,
|
||||||
notion_info={
|
notion_info=NotionInfo.model_validate(
|
||||||
"credential_id": credential_id,
|
{
|
||||||
"notion_workspace_id": workspace_id,
|
"credential_id": credential_id,
|
||||||
"notion_obj_id": page["page_id"],
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_page_type": page["type"],
|
"notion_obj_id": page["page_id"],
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"notion_page_type": page["type"],
|
||||||
},
|
"tenant_id": current_tenant_id,
|
||||||
|
}
|
||||||
|
),
|
||||||
document_model=args["doc_form"],
|
document_model=args["doc_form"],
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
@ -530,15 +608,17 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
website_info_list = args["info_list"]["website_info_list"]
|
website_info_list = args["info_list"]["website_info_list"]
|
||||||
for url in website_info_list["urls"]:
|
for url in website_info_list["urls"]:
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.WEBSITE.value,
|
datasource_type=DatasourceType.WEBSITE,
|
||||||
website_info={
|
website_info=WebsiteInfo.model_validate(
|
||||||
"provider": website_info_list["provider"],
|
{
|
||||||
"job_id": website_info_list["job_id"],
|
"provider": website_info_list["provider"],
|
||||||
"url": url,
|
"job_id": website_info_list["job_id"],
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"url": url,
|
||||||
"mode": "crawl",
|
"tenant_id": current_tenant_id,
|
||||||
"only_main_content": website_info_list["only_main_content"],
|
"mode": "crawl",
|
||||||
},
|
"only_main_content": website_info_list["only_main_content"],
|
||||||
|
}
|
||||||
|
),
|
||||||
document_model=args["doc_form"],
|
document_model=args["doc_form"],
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
@ -547,7 +627,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
try:
|
try:
|
||||||
response = indexing_runner.indexing_estimate(
|
response = indexing_runner.indexing_estimate(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
extract_settings,
|
extract_settings,
|
||||||
args["process_rule"],
|
args["process_rule"],
|
||||||
args["doc_form"],
|
args["doc_form"],
|
||||||
@ -578,6 +658,7 @@ class DatasetRelatedAppListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(related_app_list)
|
@marshal_with(related_app_list)
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -609,11 +690,10 @@ class DatasetIndexingStatusApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
documents = db.session.scalars(
|
documents = db.session.scalars(
|
||||||
select(Document).where(
|
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
|
||||||
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
).all()
|
).all()
|
||||||
documents_status = []
|
documents_status = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
@ -665,10 +745,9 @@ class DatasetApiKeyApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_key_list)
|
@marshal_with(api_key_list)
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
keys = db.session.scalars(
|
keys = db.session.scalars(
|
||||||
select(ApiToken).where(
|
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||||
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
).all()
|
).all()
|
||||||
return {"items": keys}
|
return {"items": keys}
|
||||||
|
|
||||||
@ -678,17 +757,18 @@ class DatasetApiKeyApi(Resource):
|
|||||||
@marshal_with(api_key_fields)
|
@marshal_with(api_key_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
if not current_user.is_admin_or_owner:
|
if not current_user.is_admin_or_owner:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
current_key_count = (
|
current_key_count = (
|
||||||
db.session.query(ApiToken)
|
db.session.query(ApiToken)
|
||||||
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
|
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
|
|
||||||
if current_key_count >= self.max_keys:
|
if current_key_count >= self.max_keys:
|
||||||
flask_restx.abort(
|
api.abort(
|
||||||
400,
|
400,
|
||||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||||
code="max_keys_exceeded",
|
code="max_keys_exceeded",
|
||||||
@ -696,7 +776,7 @@ class DatasetApiKeyApi(Resource):
|
|||||||
|
|
||||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||||
api_token = ApiToken()
|
api_token = ApiToken()
|
||||||
api_token.tenant_id = current_user.current_tenant_id
|
api_token.tenant_id = current_tenant_id
|
||||||
api_token.token = key
|
api_token.token = key
|
||||||
api_token.type = self.resource_type
|
api_token.type = self.resource_type
|
||||||
db.session.add(api_token)
|
db.session.add(api_token)
|
||||||
@ -716,6 +796,7 @@ class DatasetApiDeleteApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, api_key_id):
|
def delete(self, api_key_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
api_key_id = str(api_key_id)
|
api_key_id = str(api_key_id)
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin or owner
|
# The role of the current user in the ta table must be admin or owner
|
||||||
@ -725,7 +806,7 @@ class DatasetApiDeleteApi(Resource):
|
|||||||
key = (
|
key = (
|
||||||
db.session.query(ApiToken)
|
db.session.query(ApiToken)
|
||||||
.where(
|
.where(
|
||||||
ApiToken.tenant_id == current_user.current_tenant_id,
|
ApiToken.tenant_id == current_tenant_id,
|
||||||
ApiToken.type == self.resource_type,
|
ApiToken.type == self.resource_type,
|
||||||
ApiToken.id == api_key_id,
|
ApiToken.id == api_key_id,
|
||||||
)
|
)
|
||||||
@ -733,7 +814,7 @@ class DatasetApiDeleteApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if key is None:
|
if key is None:
|
||||||
flask_restx.abort(404, message="API key not found")
|
api.abort(404, message="API key not found")
|
||||||
|
|
||||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
@ -776,49 +857,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
vector_type = dify_config.VECTOR_STORE
|
vector_type = dify_config.VECTOR_STORE
|
||||||
match vector_type:
|
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
|
||||||
case (
|
|
||||||
VectorType.RELYT
|
|
||||||
| VectorType.TIDB_VECTOR
|
|
||||||
| VectorType.CHROMA
|
|
||||||
| VectorType.PGVECTO_RS
|
|
||||||
| VectorType.VIKINGDB
|
|
||||||
| VectorType.UPSTASH
|
|
||||||
):
|
|
||||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
|
||||||
case (
|
|
||||||
VectorType.QDRANT
|
|
||||||
| VectorType.WEAVIATE
|
|
||||||
| VectorType.OPENSEARCH
|
|
||||||
| VectorType.ANALYTICDB
|
|
||||||
| VectorType.MYSCALE
|
|
||||||
| VectorType.ORACLE
|
|
||||||
| VectorType.ELASTICSEARCH
|
|
||||||
| VectorType.ELASTICSEARCH_JA
|
|
||||||
| VectorType.PGVECTOR
|
|
||||||
| VectorType.VASTBASE
|
|
||||||
| VectorType.TIDB_ON_QDRANT
|
|
||||||
| VectorType.LINDORM
|
|
||||||
| VectorType.COUCHBASE
|
|
||||||
| VectorType.MILVUS
|
|
||||||
| VectorType.OPENGAUSS
|
|
||||||
| VectorType.OCEANBASE
|
|
||||||
| VectorType.TABLESTORE
|
|
||||||
| VectorType.HUAWEI_CLOUD
|
|
||||||
| VectorType.TENCENT
|
|
||||||
| VectorType.MATRIXONE
|
|
||||||
| VectorType.CLICKZETTA
|
|
||||||
| VectorType.BAIDU
|
|
||||||
):
|
|
||||||
return {
|
|
||||||
"retrieval_method": [
|
|
||||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
|
||||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
|
||||||
RetrievalMethod.HYBRID_SEARCH.value,
|
|
||||||
]
|
|
||||||
}
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||||
@ -831,48 +870,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, vector_type):
|
def get(self, vector_type):
|
||||||
match vector_type:
|
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
|
||||||
case (
|
|
||||||
VectorType.MILVUS
|
|
||||||
| VectorType.RELYT
|
|
||||||
| VectorType.TIDB_VECTOR
|
|
||||||
| VectorType.CHROMA
|
|
||||||
| VectorType.PGVECTO_RS
|
|
||||||
| VectorType.VIKINGDB
|
|
||||||
| VectorType.UPSTASH
|
|
||||||
):
|
|
||||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
|
||||||
case (
|
|
||||||
VectorType.QDRANT
|
|
||||||
| VectorType.WEAVIATE
|
|
||||||
| VectorType.OPENSEARCH
|
|
||||||
| VectorType.ANALYTICDB
|
|
||||||
| VectorType.MYSCALE
|
|
||||||
| VectorType.ORACLE
|
|
||||||
| VectorType.ELASTICSEARCH
|
|
||||||
| VectorType.ELASTICSEARCH_JA
|
|
||||||
| VectorType.COUCHBASE
|
|
||||||
| VectorType.PGVECTOR
|
|
||||||
| VectorType.VASTBASE
|
|
||||||
| VectorType.LINDORM
|
|
||||||
| VectorType.OPENGAUSS
|
|
||||||
| VectorType.OCEANBASE
|
|
||||||
| VectorType.TABLESTORE
|
|
||||||
| VectorType.TENCENT
|
|
||||||
| VectorType.HUAWEI_CLOUD
|
|
||||||
| VectorType.MATRIXONE
|
|
||||||
| VectorType.CLICKZETTA
|
|
||||||
| VectorType.BAIDU
|
|
||||||
):
|
|
||||||
return {
|
|
||||||
"retrieval_method": [
|
|
||||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
|
||||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
|
||||||
RetrievalMethod.HYBRID_SEARCH.value,
|
|
||||||
]
|
|
||||||
}
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||||
@ -907,6 +905,7 @@ class DatasetPermissionUserListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
|
|||||||
@ -4,8 +4,8 @@ from argparse import ArgumentTypeError
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Literal, cast
|
from typing import Literal, cast
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||||
from sqlalchemy import asc, desc, select
|
from sqlalchemy import asc, desc, select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
@ -43,7 +43,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.document_fields import (
|
from fields.document_fields import (
|
||||||
dataset_and_document_fields,
|
dataset_and_document_fields,
|
||||||
@ -52,7 +52,7 @@ from fields.document_fields import (
|
|||||||
document_with_segments_fields,
|
document_with_segments_fields,
|
||||||
)
|
)
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||||
from models.dataset import DocumentPipelineExecutionLog
|
from models.dataset import DocumentPipelineExecutionLog
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
@ -63,6 +63,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class DocumentResource(Resource):
|
class DocumentResource(Resource):
|
||||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
@ -77,12 +78,13 @@ class DocumentResource(Resource):
|
|||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
if document.tenant_id != current_user.current_tenant_id:
|
if document.tenant_id != current_tenant_id:
|
||||||
raise Forbidden("No permission.")
|
raise Forbidden("No permission.")
|
||||||
|
|
||||||
return document
|
return document
|
||||||
|
|
||||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
@ -110,6 +112,7 @@ class GetProcessRuleApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
req_data = request.args
|
req_data = request.args
|
||||||
|
|
||||||
document_id = req_data.get("document_id")
|
document_id = req_data.get("document_id")
|
||||||
@ -166,6 +169,7 @@ class DatasetDocumentListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id):
|
def get(self, dataset_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
@ -197,7 +201,7 @@ class DatasetDocumentListApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
|
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
|
||||||
|
|
||||||
if search:
|
if search:
|
||||||
search = f"%{search}%"
|
search = f"%{search}%"
|
||||||
@ -211,13 +215,13 @@ class DatasetDocumentListApi(Resource):
|
|||||||
|
|
||||||
if sort == "hit_count":
|
if sort == "hit_count":
|
||||||
sub_query = (
|
sub_query = (
|
||||||
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||||
.group_by(DocumentSegment.document_id)
|
.group_by(DocumentSegment.document_id)
|
||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
|
|
||||||
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
|
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
|
||||||
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
|
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||||
sort_logic(Document.position),
|
sort_logic(Document.position),
|
||||||
)
|
)
|
||||||
elif sort == "created_at":
|
elif sort == "created_at":
|
||||||
@ -271,6 +275,7 @@ class DatasetDocumentListApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
|
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -287,23 +292,23 @@ class DatasetDocumentListApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
.add_argument(
|
||||||
)
|
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||||
parser.add_argument("data_source", type=dict, required=False, location="json")
|
)
|
||||||
parser.add_argument("process_rule", type=dict, required=False, location="json")
|
.add_argument("data_source", type=dict, required=False, location="json")
|
||||||
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
.add_argument("process_rule", type=dict, required=False, location="json")
|
||||||
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
||||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument(
|
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
knowledge_config = KnowledgeConfig(**args)
|
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||||
|
|
||||||
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
|
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
|
||||||
raise ValueError("indexing_technique is required.")
|
raise ValueError("indexing_technique is required.")
|
||||||
@ -370,37 +375,38 @@ class DatasetInitApi(Resource):
|
|||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"indexing_technique",
|
.add_argument(
|
||||||
type=str,
|
"indexing_technique",
|
||||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
type=str,
|
||||||
required=True,
|
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||||
nullable=False,
|
required=True,
|
||||||
location="json",
|
nullable=False,
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||||
|
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||||
|
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||||
|
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||||
|
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||||
|
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
|
||||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
|
||||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
|
||||||
parser.add_argument(
|
|
||||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
|
||||||
)
|
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
|
||||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
|
||||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
knowledge_config = KnowledgeConfig(**args)
|
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||||
if knowledge_config.indexing_technique == "high_quality":
|
if knowledge_config.indexing_technique == "high_quality":
|
||||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=args["embedding_model_provider"],
|
provider=args["embedding_model_provider"],
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=args["embedding_model"],
|
model=args["embedding_model"],
|
||||||
@ -417,7 +423,9 @@ class DatasetInitApi(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
||||||
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
|
tenant_id=current_tenant_id,
|
||||||
|
knowledge_config=knowledge_config,
|
||||||
|
account=current_user,
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError as ex:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError(ex.description)
|
raise ProviderNotInitializeError(ex.description)
|
||||||
@ -443,6 +451,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, document_id):
|
def get(self, dataset_id, document_id):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = self.get_document(dataset_id, document_id)
|
document = self.get_document(dataset_id, document_id)
|
||||||
@ -451,7 +460,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||||||
raise DocumentAlreadyFinishedError()
|
raise DocumentAlreadyFinishedError()
|
||||||
|
|
||||||
data_process_rule = document.dataset_process_rule
|
data_process_rule = document.dataset_process_rule
|
||||||
data_process_rule_dict = data_process_rule.to_dict()
|
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||||
|
|
||||||
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
|
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
|
||||||
|
|
||||||
@ -471,14 +480,14 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||||||
raise NotFound("File not found.")
|
raise NotFound("File not found.")
|
||||||
|
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
|
||||||
)
|
)
|
||||||
|
|
||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
estimate_response = indexing_runner.indexing_estimate(
|
estimate_response = indexing_runner.indexing_estimate(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
[extract_setting],
|
[extract_setting],
|
||||||
data_process_rule_dict,
|
data_process_rule_dict,
|
||||||
document.doc_form,
|
document.doc_form,
|
||||||
@ -507,13 +516,14 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, batch):
|
def get(self, dataset_id, batch):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
batch = str(batch)
|
batch = str(batch)
|
||||||
documents = self.get_batch_documents(dataset_id, batch)
|
documents = self.get_batch_documents(dataset_id, batch)
|
||||||
if not documents:
|
if not documents:
|
||||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||||
data_process_rule = documents[0].dataset_process_rule
|
data_process_rule = documents[0].dataset_process_rule
|
||||||
data_process_rule_dict = data_process_rule.to_dict()
|
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||||
extract_settings = []
|
extract_settings = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
if document.indexing_status in {"completed", "error"}:
|
if document.indexing_status in {"completed", "error"}:
|
||||||
@ -526,7 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
file_id = data_source_info["upload_file_id"]
|
file_id = data_source_info["upload_file_id"]
|
||||||
file_detail = (
|
file_detail = (
|
||||||
db.session.query(UploadFile)
|
db.session.query(UploadFile)
|
||||||
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
|
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -534,7 +544,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
raise NotFound("File not found.")
|
raise NotFound("File not found.")
|
||||||
|
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
|
|
||||||
@ -542,14 +552,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
if not data_source_info:
|
if not data_source_info:
|
||||||
continue
|
continue
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.NOTION.value,
|
datasource_type=DatasourceType.NOTION,
|
||||||
notion_info={
|
notion_info=NotionInfo.model_validate(
|
||||||
"credential_id": data_source_info["credential_id"],
|
{
|
||||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
"credential_id": data_source_info["credential_id"],
|
||||||
"notion_obj_id": data_source_info["notion_page_id"],
|
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||||
"notion_page_type": data_source_info["type"],
|
"notion_obj_id": data_source_info["notion_page_id"],
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"notion_page_type": data_source_info["type"],
|
||||||
},
|
"tenant_id": current_tenant_id,
|
||||||
|
}
|
||||||
|
),
|
||||||
document_model=document.doc_form,
|
document_model=document.doc_form,
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
@ -557,15 +569,17 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
if not data_source_info:
|
if not data_source_info:
|
||||||
continue
|
continue
|
||||||
extract_setting = ExtractSetting(
|
extract_setting = ExtractSetting(
|
||||||
datasource_type=DatasourceType.WEBSITE.value,
|
datasource_type=DatasourceType.WEBSITE,
|
||||||
website_info={
|
website_info=WebsiteInfo.model_validate(
|
||||||
"provider": data_source_info["provider"],
|
{
|
||||||
"job_id": data_source_info["job_id"],
|
"provider": data_source_info["provider"],
|
||||||
"url": data_source_info["url"],
|
"job_id": data_source_info["job_id"],
|
||||||
"tenant_id": current_user.current_tenant_id,
|
"url": data_source_info["url"],
|
||||||
"mode": data_source_info["mode"],
|
"tenant_id": current_tenant_id,
|
||||||
"only_main_content": data_source_info["only_main_content"],
|
"mode": data_source_info["mode"],
|
||||||
},
|
"only_main_content": data_source_info["only_main_content"],
|
||||||
|
}
|
||||||
|
),
|
||||||
document_model=document.doc_form,
|
document_model=document.doc_form,
|
||||||
)
|
)
|
||||||
extract_settings.append(extract_setting)
|
extract_settings.append(extract_setting)
|
||||||
@ -575,7 +589,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
indexing_runner = IndexingRunner()
|
indexing_runner = IndexingRunner()
|
||||||
try:
|
try:
|
||||||
response = indexing_runner.indexing_estimate(
|
response = indexing_runner.indexing_estimate(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
extract_settings,
|
extract_settings,
|
||||||
data_process_rule_dict,
|
data_process_rule_dict,
|
||||||
document.doc_form,
|
document.doc_form,
|
||||||
@ -732,7 +746,7 @@ class DocumentApi(DocumentResource):
|
|||||||
"name": document.name,
|
"name": document.name,
|
||||||
"created_from": document.created_from,
|
"created_from": document.created_from,
|
||||||
"created_by": document.created_by,
|
"created_by": document.created_by,
|
||||||
"created_at": document.created_at.timestamp(),
|
"created_at": int(document.created_at.timestamp()),
|
||||||
"tokens": document.tokens,
|
"tokens": document.tokens,
|
||||||
"indexing_status": document.indexing_status,
|
"indexing_status": document.indexing_status,
|
||||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||||
@ -752,7 +766,7 @@ class DocumentApi(DocumentResource):
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||||
document_process_rules = document.dataset_process_rule.to_dict()
|
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||||
data_source_info = document.data_source_detail_dict
|
data_source_info = document.data_source_detail_dict
|
||||||
response = {
|
response = {
|
||||||
"id": document.id,
|
"id": document.id,
|
||||||
@ -765,7 +779,7 @@ class DocumentApi(DocumentResource):
|
|||||||
"name": document.name,
|
"name": document.name,
|
||||||
"created_from": document.created_from,
|
"created_from": document.created_from,
|
||||||
"created_by": document.created_by,
|
"created_by": document.created_by,
|
||||||
"created_at": document.created_at.timestamp(),
|
"created_at": int(document.created_at.timestamp()),
|
||||||
"tokens": document.tokens,
|
"tokens": document.tokens,
|
||||||
"indexing_status": document.indexing_status,
|
"indexing_status": document.indexing_status,
|
||||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||||
@ -826,6 +840,7 @@ class DocumentProcessingApi(DocumentResource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
|
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = self.get_document(dataset_id, document_id)
|
document = self.get_document(dataset_id, document_id)
|
||||||
@ -876,6 +891,7 @@ class DocumentMetadataApi(DocumentResource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def put(self, dataset_id, document_id):
|
def put(self, dataset_id, document_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
document = self.get_document(dataset_id, document_id)
|
document = self.get_document(dataset_id, document_id)
|
||||||
@ -923,6 +939,7 @@ class DocumentStatusApi(DocumentResource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -1026,8 +1043,9 @@ class DocumentRetryApi(DocumentResource):
|
|||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
"""retry document."""
|
"""retry document."""
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json")
|
"document_ids", type=list, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -1069,12 +1087,14 @@ class DocumentRenameApi(DocumentResource):
|
|||||||
@marshal_with(document_fields)
|
@marshal_with(document_fields)
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
|
if not dataset:
|
||||||
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1092,6 +1112,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, document_id):
|
def get(self, dataset_id, document_id):
|
||||||
"""sync website document."""
|
"""sync website document."""
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
@ -1100,7 +1121,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
|||||||
document = DocumentService.get_document(dataset.id, document_id)
|
document = DocumentService.get_document(dataset.id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
if document.tenant_id != current_user.current_tenant_id:
|
if document.tenant_id != current_tenant_id:
|
||||||
raise Forbidden("No permission.")
|
raise Forbidden("No permission.")
|
||||||
if document.data_source_type != "website_crawl":
|
if document.data_source_type != "website_crawl":
|
||||||
raise ValueError("Document is not a website document.")
|
raise ValueError("Document is not a website document.")
|
||||||
@ -1113,6 +1134,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
|||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log")
|
||||||
class DocumentPipelineExecutionLogApi(DocumentResource):
|
class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -1146,29 +1168,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
|||||||
"input_data": log.input_data,
|
"input_data": log.input_data,
|
||||||
"datasource_node_id": log.datasource_node_id,
|
"datasource_node_id": log.datasource_node_id,
|
||||||
}, 200
|
}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
|
|
||||||
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
|
||||||
api.add_resource(DatasetInitApi, "/datasets/init")
|
|
||||||
api.add_resource(
|
|
||||||
DocumentIndexingEstimateApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate"
|
|
||||||
)
|
|
||||||
api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
|
|
||||||
api.add_resource(DocumentBatchIndexingStatusApi, "/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
|
||||||
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
|
||||||
api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
|
||||||
api.add_resource(
|
|
||||||
DocumentProcessingApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>"
|
|
||||||
)
|
|
||||||
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
|
|
||||||
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
|
|
||||||
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
|
|
||||||
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
|
|
||||||
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
|
|
||||||
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
|
||||||
|
|
||||||
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
|
||||||
api.add_resource(
|
|
||||||
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,13 +1,12 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal, reqparse
|
from flask_restx import Resource, marshal, reqparse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import ProviderNotInitializeError
|
from controllers.console.app.error import ProviderNotInitializeError
|
||||||
from controllers.console.datasets.error import (
|
from controllers.console.datasets.error import (
|
||||||
ChildChunkDeleteIndexError,
|
ChildChunkDeleteIndexError,
|
||||||
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.dataset import ChildChunk, DocumentSegment
|
from models.dataset import ChildChunk, DocumentSegment
|
||||||
from models.model import UploadFile
|
from models.model import UploadFile
|
||||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||||
@ -37,11 +36,14 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
|||||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||||
class DatasetDocumentSegmentListApi(Resource):
|
class DatasetDocumentSegmentListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, document_id):
|
def get(self, dataset_id, document_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
document_id = str(document_id)
|
document_id = str(document_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -58,13 +60,15 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("limit", type=int, default=20, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("status", type=str, action="append", default=[], location="args")
|
.add_argument("limit", type=int, default=20, location="args")
|
||||||
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
|
.add_argument("status", type=str, action="append", default=[], location="args")
|
||||||
parser.add_argument("enabled", type=str, default="all", location="args")
|
.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
.add_argument("enabled", type=str, default="all", location="args")
|
||||||
parser.add_argument("page", type=int, default=1, location="args")
|
.add_argument("keyword", type=str, default=None, location="args")
|
||||||
|
.add_argument("page", type=int, default=1, location="args")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -78,7 +82,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
select(DocumentSegment)
|
select(DocumentSegment)
|
||||||
.where(
|
.where(
|
||||||
DocumentSegment.document_id == str(document_id),
|
DocumentSegment.document_id == str(document_id),
|
||||||
DocumentSegment.tenant_id == current_user.current_tenant_id,
|
DocumentSegment.tenant_id == current_tenant_id,
|
||||||
)
|
)
|
||||||
.order_by(DocumentSegment.position.asc())
|
.order_by(DocumentSegment.position.asc())
|
||||||
)
|
)
|
||||||
@ -114,6 +118,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id, document_id):
|
def delete(self, dataset_id, document_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -139,6 +145,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
|||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
|
||||||
class DatasetDocumentSegmentApi(Resource):
|
class DatasetDocumentSegmentApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -146,6 +153,8 @@ class DatasetDocumentSegmentApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, action):
|
def patch(self, dataset_id, document_id, action):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
@ -169,7 +178,7 @@ class DatasetDocumentSegmentApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -193,6 +202,7 @@ class DatasetDocumentSegmentApi(Resource):
|
|||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||||
class DatasetDocumentSegmentAddApi(Resource):
|
class DatasetDocumentSegmentAddApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -201,6 +211,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -218,7 +230,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -234,16 +246,19 @@ class DatasetDocumentSegmentAddApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
SegmentService.segment_create_args_validate(args, document)
|
SegmentService.segment_create_args_validate(args, document)
|
||||||
segment = SegmentService.create_segment(args, document, dataset)
|
segment = SegmentService.create_segment(args, document, dataset)
|
||||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -251,6 +266,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, segment_id):
|
def patch(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -268,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -283,7 +300,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -296,16 +313,18 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument(
|
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
.add_argument(
|
||||||
|
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
SegmentService.segment_create_args_validate(args, document)
|
SegmentService.segment_create_args_validate(args, document)
|
||||||
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
|
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
|
||||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@ -313,6 +332,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id, document_id, segment_id):
|
def delete(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -329,7 +350,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -345,6 +366,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
|||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
||||||
|
"/datasets/batch_import_status/<uuid:job_id>",
|
||||||
|
)
|
||||||
class DatasetDocumentSegmentBatchImportApi(Resource):
|
class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -353,6 +378,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id, document_id):
|
def post(self, dataset_id, document_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -364,8 +391,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
|
"upload_file_id", type=str, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
upload_file_id = args["upload_file_id"]
|
upload_file_id = args["upload_file_id"]
|
||||||
|
|
||||||
@ -384,7 +412,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||||||
# send batch add segments task
|
# send batch add segments task
|
||||||
redis_client.setnx(indexing_cache_key, "waiting")
|
redis_client.setnx(indexing_cache_key, "waiting")
|
||||||
batch_create_segment_to_index_task.delay(
|
batch_create_segment_to_index_task.delay(
|
||||||
str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
|
str(job_id),
|
||||||
|
upload_file_id,
|
||||||
|
dataset_id,
|
||||||
|
document_id,
|
||||||
|
current_tenant_id,
|
||||||
|
current_user.id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error": str(e)}, 500
|
return {"error": str(e)}, 500
|
||||||
@ -393,7 +426,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, job_id):
|
def get(self, job_id=None, dataset_id=None, document_id=None):
|
||||||
|
if job_id is None:
|
||||||
|
raise NotFound("The job does not exist.")
|
||||||
job_id = str(job_id)
|
job_id = str(job_id)
|
||||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||||
cache_result = redis_client.get(indexing_cache_key)
|
cache_result = redis_client.get(indexing_cache_key)
|
||||||
@ -403,6 +438,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
|||||||
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
|
||||||
class ChildChunkAddApi(Resource):
|
class ChildChunkAddApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -411,6 +447,8 @@ class ChildChunkAddApi(Resource):
|
|||||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self, dataset_id, document_id, segment_id):
|
def post(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -425,7 +463,7 @@ class ChildChunkAddApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -437,7 +475,7 @@ class ChildChunkAddApi(Resource):
|
|||||||
try:
|
try:
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
model_manager.get_model_instance(
|
model_manager.get_model_instance(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=dataset.embedding_model_provider,
|
provider=dataset.embedding_model_provider,
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
model=dataset.embedding_model,
|
model=dataset.embedding_model,
|
||||||
@ -453,11 +491,13 @@ class ChildChunkAddApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
"content", type=str, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
try:
|
try:
|
||||||
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
|
content = args["content"]
|
||||||
|
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
|
||||||
except ChildChunkIndexingServiceError as e:
|
except ChildChunkIndexingServiceError as e:
|
||||||
raise ChildChunkIndexingError(str(e))
|
raise ChildChunkIndexingError(str(e))
|
||||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||||
@ -466,6 +506,8 @@ class ChildChunkAddApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, dataset_id, document_id, segment_id):
|
def get(self, dataset_id, document_id, segment_id):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -482,15 +524,17 @@ class ChildChunkAddApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
raise NotFound("Segment not found.")
|
raise NotFound("Segment not found.")
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("limit", type=int, default=20, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
.add_argument("limit", type=int, default=20, location="args")
|
||||||
parser.add_argument("page", type=int, default=1, location="args")
|
.add_argument("keyword", type=str, default=None, location="args")
|
||||||
|
.add_argument("page", type=int, default=1, location="args")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -513,6 +557,8 @@ class ChildChunkAddApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, segment_id):
|
def patch(self, dataset_id, document_id, segment_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -529,7 +575,7 @@ class ChildChunkAddApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -542,23 +588,30 @@ class ChildChunkAddApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
|
"chunks", type=list, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
try:
|
try:
|
||||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
|
chunks_data = args["chunks"]
|
||||||
|
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
|
||||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
||||||
except ChildChunkIndexingServiceError as e:
|
except ChildChunkIndexingServiceError as e:
|
||||||
raise ChildChunkIndexingError(str(e))
|
raise ChildChunkIndexingError(str(e))
|
||||||
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>"
|
||||||
|
)
|
||||||
class ChildChunkUpdateApi(Resource):
|
class ChildChunkUpdateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -575,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -586,7 +639,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
db.session.query(ChildChunk)
|
db.session.query(ChildChunk)
|
||||||
.where(
|
.where(
|
||||||
ChildChunk.id == str(child_chunk_id),
|
ChildChunk.id == str(child_chunk_id),
|
||||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
ChildChunk.tenant_id == current_tenant_id,
|
||||||
ChildChunk.segment_id == segment.id,
|
ChildChunk.segment_id == segment.id,
|
||||||
ChildChunk.document_id == document_id,
|
ChildChunk.document_id == document_id,
|
||||||
)
|
)
|
||||||
@ -613,6 +666,8 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
@cloud_edition_billing_resource_check("vector_space")
|
@cloud_edition_billing_resource_check("vector_space")
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
# check dataset
|
# check dataset
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
@ -629,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
segment_id = str(segment_id)
|
segment_id = str(segment_id)
|
||||||
segment = (
|
segment = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.query(DocumentSegment)
|
||||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if not segment:
|
if not segment:
|
||||||
@ -640,7 +695,7 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
db.session.query(ChildChunk)
|
db.session.query(ChildChunk)
|
||||||
.where(
|
.where(
|
||||||
ChildChunk.id == str(child_chunk_id),
|
ChildChunk.id == str(child_chunk_id),
|
||||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
ChildChunk.tenant_id == current_tenant_id,
|
||||||
ChildChunk.segment_id == segment.id,
|
ChildChunk.segment_id == segment.id,
|
||||||
ChildChunk.document_id == document_id,
|
ChildChunk.document_id == document_id,
|
||||||
)
|
)
|
||||||
@ -656,37 +711,13 @@ class ChildChunkUpdateApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
# validate args
|
# validate args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
"content", type=str, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
try:
|
try:
|
||||||
child_chunk = SegmentService.update_child_chunk(
|
content = args["content"]
|
||||||
args.get("content"), child_chunk, segment, document, dataset
|
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
|
||||||
)
|
|
||||||
except ChildChunkIndexingServiceError as e:
|
except ChildChunkIndexingServiceError as e:
|
||||||
raise ChildChunkIndexingError(str(e))
|
raise ChildChunkIndexingError(str(e))
|
||||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
|
||||||
api.add_resource(
|
|
||||||
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
|
|
||||||
)
|
|
||||||
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
|
||||||
api.add_resource(
|
|
||||||
DatasetDocumentSegmentUpdateApi,
|
|
||||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
DatasetDocumentSegmentBatchImportApi,
|
|
||||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
|
||||||
"/datasets/batch_import_status/<uuid:job_id>",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
ChildChunkAddApi,
|
|
||||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
ChildChunkUpdateApi,
|
|
||||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal, reqparse
|
from flask_restx import Resource, fields, marshal, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
@ -8,14 +7,14 @@ from controllers.console import api, console_ns
|
|||||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.dataset_fields import dataset_detail_fields
|
from fields.dataset_fields import dataset_detail_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.dataset_service import DatasetService
|
from services.dataset_service import DatasetService
|
||||||
from services.external_knowledge_service import ExternalDatasetService
|
from services.external_knowledge_service import ExternalDatasetService
|
||||||
from services.hit_testing_service import HitTestingService
|
from services.hit_testing_service import HitTestingService
|
||||||
from services.knowledge_service import ExternalDatasetTestService
|
from services.knowledge_service import ExternalDatasetTestService
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name):
|
def _validate_name(name: str) -> str:
|
||||||
if not name or len(name) < 1 or len(name) > 100:
|
if not name or len(name) < 1 or len(name) > 100:
|
||||||
raise ValueError("Name must be between 1 to 100 characters.")
|
raise ValueError("Name must be between 1 to 100 characters.")
|
||||||
return name
|
return name
|
||||||
@ -37,12 +36,13 @@ class ExternalApiTemplateListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
page = request.args.get("page", default=1, type=int)
|
page = request.args.get("page", default=1, type=int)
|
||||||
limit = request.args.get("limit", default=20, type=int)
|
limit = request.args.get("limit", default=20, type=int)
|
||||||
search = request.args.get("keyword", default=None, type=str)
|
search = request.args.get("keyword", default=None, type=str)
|
||||||
|
|
||||||
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
||||||
page, limit, current_user.current_tenant_id, search
|
page, limit, current_tenant_id, search
|
||||||
)
|
)
|
||||||
response = {
|
response = {
|
||||||
"data": [item.to_dict() for item in external_knowledge_apis],
|
"data": [item.to_dict() for item in external_knowledge_apis],
|
||||||
@ -57,20 +57,23 @@ class ExternalApiTemplateListApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument(
|
parser = (
|
||||||
"name",
|
reqparse.RequestParser()
|
||||||
nullable=False,
|
.add_argument(
|
||||||
required=True,
|
"name",
|
||||||
help="Name is required. Name must be between 1 to 100 characters.",
|
nullable=False,
|
||||||
type=_validate_name,
|
required=True,
|
||||||
)
|
help="Name is required. Name must be between 1 to 100 characters.",
|
||||||
parser.add_argument(
|
type=_validate_name,
|
||||||
"settings",
|
)
|
||||||
type=dict,
|
.add_argument(
|
||||||
location="json",
|
"settings",
|
||||||
nullable=False,
|
type=dict,
|
||||||
required=True,
|
location="json",
|
||||||
|
nullable=False,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -82,7 +85,7 @@ class ExternalApiTemplateListApi(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
||||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
|
tenant_id=current_tenant_id, user_id=current_user.id, args=args
|
||||||
)
|
)
|
||||||
except services.errors.dataset.DatasetNameDuplicateError:
|
except services.errors.dataset.DatasetNameDuplicateError:
|
||||||
raise DatasetNameDuplicateError()
|
raise DatasetNameDuplicateError()
|
||||||
@ -112,28 +115,31 @@ class ExternalApiTemplateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def patch(self, external_knowledge_api_id):
|
def patch(self, external_knowledge_api_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
required=True,
|
nullable=False,
|
||||||
help="type is required. Name must be between 1 to 100 characters.",
|
required=True,
|
||||||
type=_validate_name,
|
help="type is required. Name must be between 1 to 100 characters.",
|
||||||
)
|
type=_validate_name,
|
||||||
parser.add_argument(
|
)
|
||||||
"settings",
|
.add_argument(
|
||||||
type=dict,
|
"settings",
|
||||||
location="json",
|
type=dict,
|
||||||
nullable=False,
|
location="json",
|
||||||
required=True,
|
nullable=False,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
ExternalDatasetService.validate_api_list(args["settings"])
|
ExternalDatasetService.validate_api_list(args["settings"])
|
||||||
|
|
||||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
external_knowledge_api_id=external_knowledge_api_id,
|
external_knowledge_api_id=external_knowledge_api_id,
|
||||||
args=args,
|
args=args,
|
||||||
@ -145,13 +151,13 @@ class ExternalApiTemplateApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, external_knowledge_api_id):
|
def delete(self, external_knowledge_api_id):
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||||
|
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||||
if not (current_user.is_editor or current_user.is_dataset_operator):
|
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
|
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
@ -196,21 +202,24 @@ class ExternalDatasetCreateApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not current_user.is_editor:
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument(
|
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
required=True,
|
nullable=False,
|
||||||
help="name is required. Name must be between 1 to 100 characters.",
|
required=True,
|
||||||
type=_validate_name,
|
help="name is required. Name must be between 1 to 100 characters.",
|
||||||
|
type=_validate_name,
|
||||||
|
)
|
||||||
|
.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
|
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -220,7 +229,7 @@ class ExternalDatasetCreateApi(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
dataset = ExternalDatasetService.create_external_dataset(
|
dataset = ExternalDatasetService.create_external_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
args=args,
|
args=args,
|
||||||
)
|
)
|
||||||
@ -252,6 +261,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -262,10 +272,12 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||||||
except services.errors.account.NoPermissionError as e:
|
except services.errors.account.NoPermissionError as e:
|
||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("query", type=str, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
.add_argument("query", type=str, location="json")
|
||||||
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
|
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
HitTestingService.hit_testing_args_check(args)
|
||||||
@ -301,15 +313,17 @@ class BedrockRetrievalApi(Resource):
|
|||||||
)
|
)
|
||||||
@api.response(200, "Bedrock retrieval test completed")
|
@api.response(200, "Bedrock retrieval test completed")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument(
|
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
||||||
"query",
|
.add_argument(
|
||||||
nullable=False,
|
"query",
|
||||||
required=True,
|
nullable=False,
|
||||||
type=str,
|
required=True,
|
||||||
|
type=str,
|
||||||
|
)
|
||||||
|
.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
||||||
)
|
)
|
||||||
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Call the knowledge retrieval service
|
# Call the knowledge retrieval service
|
||||||
|
|||||||
@ -1,10 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import marshal, reqparse
|
from flask_restx import marshal, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
import services.dataset_service
|
import services
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
CompletionRequestError,
|
CompletionRequestError,
|
||||||
ProviderModelCurrentlyNotSupportError,
|
ProviderModelCurrentlyNotSupportError,
|
||||||
@ -20,6 +19,8 @@ from core.errors.error import (
|
|||||||
)
|
)
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from fields.hit_testing_fields import hit_testing_record_fields
|
from fields.hit_testing_fields import hit_testing_record_fields
|
||||||
|
from libs.login import current_user
|
||||||
|
from models.account import Account
|
||||||
from services.dataset_service import DatasetService
|
from services.dataset_service import DatasetService
|
||||||
from services.hit_testing_service import HitTestingService
|
from services.hit_testing_service import HitTestingService
|
||||||
|
|
||||||
@ -29,6 +30,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class DatasetsHitTestingBase:
|
class DatasetsHitTestingBase:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_and_validate_dataset(dataset_id: str):
|
def get_and_validate_dataset(dataset_id: str):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
dataset = DatasetService.get_dataset(dataset_id)
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
@ -46,15 +48,17 @@ class DatasetsHitTestingBase:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
parser.add_argument("query", type=str, location="json")
|
.add_argument("query", type=str, location="json")
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def perform_hit_testing(dataset, args):
|
def perform_hit_testing(dataset, args):
|
||||||
|
assert isinstance(current_user, Account)
|
||||||
try:
|
try:
|
||||||
response = HitTestingService.retrieve(
|
response = HitTestingService.retrieve(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
|
|||||||
@ -1,13 +1,12 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||||
from fields.dataset_fields import dataset_metadata_fields
|
from fields.dataset_fields import dataset_metadata_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.dataset_service import DatasetService
|
from services.dataset_service import DatasetService
|
||||||
from services.entities.knowledge_entities.knowledge_entities import (
|
from services.entities.knowledge_entities.knowledge_entities import (
|
||||||
MetadataArgs,
|
MetadataArgs,
|
||||||
@ -16,6 +15,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
|||||||
from services.metadata_service import MetadataService
|
from services.metadata_service import MetadataService
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||||
class DatasetMetadataCreateApi(Resource):
|
class DatasetMetadataCreateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -23,11 +23,14 @@ class DatasetMetadataCreateApi(Resource):
|
|||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
@marshal_with(dataset_metadata_fields)
|
@marshal_with(dataset_metadata_fields)
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
parser = reqparse.RequestParser()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
parser = (
|
||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
|
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
metadata_args = MetadataArgs(**args)
|
metadata_args = MetadataArgs.model_validate(args)
|
||||||
|
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
@ -50,6 +53,7 @@ class DatasetMetadataCreateApi(Resource):
|
|||||||
return MetadataService.get_dataset_metadatas(dataset), 200
|
return MetadataService.get_dataset_metadatas(dataset), 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||||
class DatasetMetadataApi(Resource):
|
class DatasetMetadataApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -57,9 +61,10 @@ class DatasetMetadataApi(Resource):
|
|||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
@marshal_with(dataset_metadata_fields)
|
@marshal_with(dataset_metadata_fields)
|
||||||
def patch(self, dataset_id, metadata_id):
|
def patch(self, dataset_id, metadata_id):
|
||||||
parser = reqparse.RequestParser()
|
current_user, _ = current_account_with_tenant()
|
||||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
name = args["name"]
|
||||||
|
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
metadata_id_str = str(metadata_id)
|
metadata_id_str = str(metadata_id)
|
||||||
@ -68,7 +73,7 @@ class DatasetMetadataApi(Resource):
|
|||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
|
|
||||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
|
||||||
return metadata, 200
|
return metadata, 200
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@ -76,6 +81,7 @@ class DatasetMetadataApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def delete(self, dataset_id, metadata_id):
|
def delete(self, dataset_id, metadata_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
metadata_id_str = str(metadata_id)
|
metadata_id_str = str(metadata_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
@ -87,6 +93,7 @@ class DatasetMetadataApi(Resource):
|
|||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/metadata/built-in")
|
||||||
class DatasetMetadataBuiltInFieldApi(Resource):
|
class DatasetMetadataBuiltInFieldApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -97,12 +104,14 @@ class DatasetMetadataBuiltInFieldApi(Resource):
|
|||||||
return {"fields": built_in_fields}, 200
|
return {"fields": built_in_fields}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||||
class DatasetMetadataBuiltInFieldActionApi(Resource):
|
class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
@ -116,30 +125,26 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
|||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||||
class DocumentMetadataEditApi(Resource):
|
class DocumentMetadataEditApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
|
"operation_data", type=list, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
metadata_args = MetadataOperationData(**args)
|
metadata_args = MetadataOperationData.model_validate(args)
|
||||||
|
|
||||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||||
|
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
|
|
||||||
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
|
||||||
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
|
|
||||||
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
|
||||||
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")
|
|
||||||
|
|||||||
@ -1,33 +1,30 @@
|
|||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
from flask import make_response, redirect, request
|
from flask import make_response, redirect, request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||||
account_initialization_required,
|
|
||||||
setup_required,
|
|
||||||
)
|
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from libs.helper import StrLen
|
from libs.helper import StrLen
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.provider_ids import DatasourceProviderID
|
from models.provider_ids import DatasourceProviderID
|
||||||
from services.datasource_provider_service import DatasourceProviderService
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
from services.plugin.oauth_service import OAuthProxyService
|
from services.plugin.oauth_service import OAuthProxyService
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
|
||||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def get(self, provider_id: str):
|
def get(self, provider_id: str):
|
||||||
user = current_user
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
if not current_user.is_editor:
|
tenant_id = current_tenant_id
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
credential_id = request.args.get("credential_id")
|
credential_id = request.args.get("credential_id")
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
@ -51,7 +48,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
|||||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||||
authorization_url_response = oauth_handler.get_authorization_url(
|
authorization_url_response = oauth_handler.get_authorization_url(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user.id,
|
user_id=current_user.id,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
redirect_uri=redirect_uri,
|
redirect_uri=redirect_uri,
|
||||||
@ -68,6 +65,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/callback")
|
||||||
class DatasourceOAuthCallback(Resource):
|
class DatasourceOAuthCallback(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
def get(self, provider_id: str):
|
def get(self, provider_id: str):
|
||||||
@ -123,26 +121,29 @@ class DatasourceOAuthCallback(Resource):
|
|||||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||||
class DatasourceAuth(Resource):
|
class DatasourceAuth(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
.add_argument(
|
||||||
|
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
||||||
|
)
|
||||||
|
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
datasource_provider_service.add_datasource_api_key_provider(
|
datasource_provider_service.add_datasource_api_key_provider(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider_id=datasource_provider_id,
|
provider_id=datasource_provider_id,
|
||||||
credentials=args["credentials"],
|
credentials=args["credentials"],
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
@ -157,30 +158,36 @@ class DatasourceAuth(Resource):
|
|||||||
def get(self, provider_id: str):
|
def get(self, provider_id: str):
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasources = datasource_provider_service.list_datasource_credentials(
|
datasources = datasource_provider_service.list_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
provider=datasource_provider_id.provider_name,
|
provider=datasource_provider_id.provider_name,
|
||||||
plugin_id=datasource_provider_id.plugin_id,
|
plugin_id=datasource_provider_id.plugin_id,
|
||||||
)
|
)
|
||||||
return {"result": datasources}, 200
|
return {"result": datasources}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||||
class DatasourceAuthDeleteApi(Resource):
|
class DatasourceAuthDeleteApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
plugin_id = datasource_provider_id.plugin_id
|
plugin_id = datasource_provider_id.plugin_id
|
||||||
provider_name = datasource_provider_id.provider_name
|
provider_name = datasource_provider_id.provider_name
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser = reqparse.RequestParser()
|
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.remove_datasource_credentials(
|
datasource_provider_service.remove_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
auth_id=args["credential_id"],
|
auth_id=args["credential_id"],
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
@ -188,22 +195,27 @@ class DatasourceAuthDeleteApi(Resource):
|
|||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||||
class DatasourceAuthUpdateApi(Resource):
|
class DatasourceAuthUpdateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if not current_user.is_editor:
|
|
||||||
raise Forbidden()
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.update_datasource_credentials(
|
datasource_provider_service.update_datasource_credentials(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
auth_id=args["credential_id"],
|
auth_id=args["credential_id"],
|
||||||
provider=datasource_provider_id.provider_name,
|
provider=datasource_provider_id.provider_name,
|
||||||
plugin_id=datasource_provider_id.plugin_id,
|
plugin_id=datasource_provider_id.plugin_id,
|
||||||
@ -213,45 +225,51 @@ class DatasourceAuthUpdateApi(Resource):
|
|||||||
return {"result": "success"}, 201
|
return {"result": "success"}, 201
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/auth/plugin/datasource/list")
|
||||||
class DatasourceAuthListApi(Resource):
|
class DatasourceAuthListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasources = datasource_provider_service.get_all_datasource_credentials(
|
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||||
tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
return {"result": jsonable_encoder(datasources)}, 200
|
return {"result": jsonable_encoder(datasources)}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/auth/plugin/datasource/default-list")
|
||||||
class DatasourceHardCodeAuthListApi(Resource):
|
class DatasourceHardCodeAuthListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
|
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||||
tenant_id=current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
return {"result": jsonable_encoder(datasources)}, 200
|
return {"result": jsonable_encoder(datasources)}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||||
class DatasourceAuthOauthCustomClient(Resource):
|
class DatasourceAuthOauthCustomClient(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||||
|
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.setup_oauth_custom_client_params(
|
datasource_provider_service.setup_oauth_custom_client_params(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
client_params=args.get("client_params", {}),
|
client_params=args.get("client_params", {}),
|
||||||
enabled=args.get("enable_oauth_custom_client", False),
|
enabled=args.get("enable_oauth_custom_client", False),
|
||||||
@ -262,101 +280,59 @@ class DatasourceAuthOauthCustomClient(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, provider_id: str):
|
def delete(self, provider_id: str):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.remove_oauth_custom_client_params(
|
datasource_provider_service.remove_oauth_custom_client_params(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
)
|
)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||||
class DatasourceAuthDefaultApi(Resource):
|
class DatasourceAuthDefaultApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.set_default_datasource_provider(
|
datasource_provider_service.set_default_datasource_provider(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
credential_id=args["id"],
|
credential_id=args["id"],
|
||||||
)
|
)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||||
class DatasourceUpdateProviderNameApi(Resource):
|
class DatasourceUpdateProviderNameApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
|
@edit_permission_required
|
||||||
def post(self, provider_id: str):
|
def post(self, provider_id: str):
|
||||||
if not current_user.is_editor:
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||||
datasource_provider_service = DatasourceProviderService()
|
datasource_provider_service = DatasourceProviderService()
|
||||||
datasource_provider_service.update_datasource_provider_name(
|
datasource_provider_service.update_datasource_provider_name(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
datasource_provider_id=datasource_provider_id,
|
datasource_provider_id=datasource_provider_id,
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
credential_id=args["credential_id"],
|
credential_id=args["credential_id"],
|
||||||
)
|
)
|
||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
DatasourcePluginOAuthAuthorizationUrl,
|
|
||||||
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
DatasourceOAuthCallback,
|
|
||||||
"/oauth/plugin/<path:provider_id>/datasource/callback",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
DatasourceAuth,
|
|
||||||
"/auth/plugin/datasource/<path:provider_id>",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
DatasourceAuthUpdateApi,
|
|
||||||
"/auth/plugin/datasource/<path:provider_id>/update",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
DatasourceAuthDeleteApi,
|
|
||||||
"/auth/plugin/datasource/<path:provider_id>/delete",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
DatasourceAuthListApi,
|
|
||||||
"/auth/plugin/datasource/list",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
DatasourceHardCodeAuthListApi,
|
|
||||||
"/auth/plugin/datasource/default-list",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
DatasourceAuthOauthCustomClient,
|
|
||||||
"/auth/plugin/datasource/<path:provider_id>/custom-client",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
DatasourceAuthDefaultApi,
|
|
||||||
"/auth/plugin/datasource/<path:provider_id>/default",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
DatasourceUpdateProviderNameApi,
|
|
||||||
"/auth/plugin/datasource/<path:provider_id>/update-name",
|
|
||||||
)
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
|
|||||||
)
|
)
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_user, login_required
|
||||||
@ -12,8 +12,17 @@ from models import Account
|
|||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
|
|
||||||
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||||
class DataSourceContentPreviewApi(Resource):
|
class DataSourceContentPreviewApi(Resource):
|
||||||
|
@api.expect(parser)
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@ -25,10 +34,6 @@ class DataSourceContentPreviewApi(Resource):
|
|||||||
if not isinstance(current_user, Account):
|
if not isinstance(current_user, Account):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
|
||||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
inputs = args.get("inputs")
|
||||||
@ -49,9 +54,3 @@ class DataSourceContentPreviewApi(Resource):
|
|||||||
credential_id=args.get("credential_id"),
|
credential_id=args.get("credential_id"),
|
||||||
)
|
)
|
||||||
return preview_content, 200
|
return preview_content, 200
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
DataSourceContentPreviewApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
|
|
||||||
)
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from flask import request
|
|||||||
from flask_restx import Resource, reqparse
|
from flask_restx import Resource, reqparse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
enterprise_license_required,
|
enterprise_license_required,
|
||||||
@ -20,18 +20,19 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name):
|
def _validate_name(name: str) -> str:
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
if not name or len(name) < 1 or len(name) > 40:
|
||||||
raise ValueError("Name must be between 1 to 40 characters.")
|
raise ValueError("Name must be between 1 to 40 characters.")
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def _validate_description_length(description):
|
def _validate_description_length(description: str) -> str:
|
||||||
if len(description) > 400:
|
if len(description) > 400:
|
||||||
raise ValueError("Description cannot exceed 400 characters.")
|
raise ValueError("Description cannot exceed 400 characters.")
|
||||||
return description
|
return description
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipeline/templates")
|
||||||
class PipelineTemplateListApi(Resource):
|
class PipelineTemplateListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -45,6 +46,7 @@ class PipelineTemplateListApi(Resource):
|
|||||||
return pipeline_templates, 200
|
return pipeline_templates, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipeline/templates/<string:template_id>")
|
||||||
class PipelineTemplateDetailApi(Resource):
|
class PipelineTemplateDetailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -57,35 +59,38 @@ class PipelineTemplateDetailApi(Resource):
|
|||||||
return pipeline_template, 200
|
return pipeline_template, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||||
class CustomizedPipelineTemplateApi(Resource):
|
class CustomizedPipelineTemplateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
def patch(self, template_id: str):
|
def patch(self, template_id: str):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
required=True,
|
nullable=False,
|
||||||
help="Name must be between 1 to 40 characters.",
|
required=True,
|
||||||
type=_validate_name,
|
help="Name must be between 1 to 40 characters.",
|
||||||
)
|
type=_validate_name,
|
||||||
parser.add_argument(
|
)
|
||||||
"description",
|
.add_argument(
|
||||||
type=str,
|
"description",
|
||||||
nullable=True,
|
type=_validate_description_length,
|
||||||
required=False,
|
nullable=True,
|
||||||
default="",
|
required=False,
|
||||||
)
|
default="",
|
||||||
parser.add_argument(
|
)
|
||||||
"icon_info",
|
.add_argument(
|
||||||
type=dict,
|
"icon_info",
|
||||||
location="json",
|
type=dict,
|
||||||
nullable=True,
|
location="json",
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
|
||||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||||
return 200
|
return 200
|
||||||
|
|
||||||
@ -112,6 +117,7 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||||||
return {"data": template.yaml_content}, 200
|
return {"data": template.yaml_content}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
||||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -119,46 +125,30 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
|||||||
@enterprise_license_required
|
@enterprise_license_required
|
||||||
@knowledge_pipeline_publish_enabled
|
@knowledge_pipeline_publish_enabled
|
||||||
def post(self, pipeline_id: str):
|
def post(self, pipeline_id: str):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"name",
|
.add_argument(
|
||||||
nullable=False,
|
"name",
|
||||||
required=True,
|
nullable=False,
|
||||||
help="Name must be between 1 to 40 characters.",
|
required=True,
|
||||||
type=_validate_name,
|
help="Name must be between 1 to 40 characters.",
|
||||||
)
|
type=_validate_name,
|
||||||
parser.add_argument(
|
)
|
||||||
"description",
|
.add_argument(
|
||||||
type=str,
|
"description",
|
||||||
nullable=True,
|
type=_validate_description_length,
|
||||||
required=False,
|
nullable=True,
|
||||||
default="",
|
required=False,
|
||||||
)
|
default="",
|
||||||
parser.add_argument(
|
)
|
||||||
"icon_info",
|
.add_argument(
|
||||||
type=dict,
|
"icon_info",
|
||||||
location="json",
|
type=dict,
|
||||||
nullable=True,
|
location="json",
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
PipelineTemplateListApi,
|
|
||||||
"/rag/pipeline/templates",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
PipelineTemplateDetailApi,
|
|
||||||
"/rag/pipeline/templates/<string:template_id>",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
CustomizedPipelineTemplateApi,
|
|
||||||
"/rag/pipeline/customized/templates/<string:template_id>",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
PublishCustomizedPipelineTemplateApi,
|
|
||||||
"/rag/pipelines/<string:pipeline_id>/customized/publish",
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,10 +1,9 @@
|
|||||||
from flask_login import current_user # type: ignore # type: ignore
|
from flask_restx import Resource, marshal, reqparse
|
||||||
from flask_restx import Resource, marshal, reqparse # type: ignore
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
@ -13,34 +12,21 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.dataset_fields import dataset_detail_fields
|
from fields.dataset_fields import dataset_detail_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.dataset import DatasetPermissionEnum
|
from models.dataset import DatasetPermissionEnum
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||||
|
|
||||||
|
|
||||||
def _validate_name(name):
|
@console_ns.route("/rag/pipeline/dataset")
|
||||||
if not name or len(name) < 1 or len(name) > 40:
|
|
||||||
raise ValueError("Name must be between 1 to 40 characters.")
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_description_length(description):
|
|
||||||
if len(description) > 400:
|
|
||||||
raise ValueError("Description cannot exceed 400 characters.")
|
|
||||||
return description
|
|
||||||
|
|
||||||
|
|
||||||
class CreateRagPipelineDatasetApi(Resource):
|
class CreateRagPipelineDatasetApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"yaml_content",
|
"yaml_content",
|
||||||
type=str,
|
type=str,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
@ -49,7 +35,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
@ -69,12 +55,12 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||||
)
|
)
|
||||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||||
DatasetPermissionService.update_partial_member_list(
|
DatasetPermissionService.update_partial_member_list(
|
||||||
current_user.current_tenant_id,
|
current_tenant_id,
|
||||||
import_info["dataset_id"],
|
import_info["dataset_id"],
|
||||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||||
)
|
)
|
||||||
@ -84,6 +70,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
|||||||
return import_info, 201
|
return import_info, 201
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipeline/empty-dataset")
|
||||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -91,10 +78,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
|||||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
def post(self):
|
def post(self):
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
if not current_user.is_dataset_editor:
|
if not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
||||||
name="",
|
name="",
|
||||||
description="",
|
description="",
|
||||||
@ -108,7 +97,3 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
return marshal(dataset, dataset_detail_fields), 201
|
return marshal(dataset, dataset_detail_fields), 201
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
|
||||||
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
|
||||||
|
|||||||
@ -1,31 +1,29 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, NoReturn
|
from typing import NoReturn
|
||||||
|
|
||||||
from flask import Response
|
from flask import Response
|
||||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
DraftWorkflowNotExist,
|
DraftWorkflowNotExist,
|
||||||
)
|
)
|
||||||
from controllers.console.app.workflow_draft_variable import (
|
from controllers.console.app.workflow_draft_variable import (
|
||||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
|
_WORKFLOW_DRAFT_VARIABLE_FIELDS, # type: ignore[private-usage]
|
||||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, # type: ignore[private-usage]
|
||||||
)
|
)
|
||||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||||
from core.variables.segment_group import SegmentGroup
|
|
||||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
|
||||||
from core.variables.types import SegmentType
|
from core.variables.types import SegmentType
|
||||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||||
from factories.variable_factory import build_segment_with_type
|
from factories.variable_factory import build_segment_with_type
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_user, login_required
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
from models.workflow import WorkflowDraftVariable
|
from models.workflow import WorkflowDraftVariable
|
||||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
@ -34,43 +32,19 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
|
||||||
if isinstance(value, FileSegment):
|
|
||||||
return value.value.model_dump()
|
|
||||||
elif isinstance(value, ArrayFileSegment):
|
|
||||||
return [i.model_dump() for i in value.value]
|
|
||||||
elif isinstance(value, SegmentGroup):
|
|
||||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
|
||||||
else:
|
|
||||||
return value.value
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
|
||||||
value = variable.get_value()
|
|
||||||
# create a copy of the value to avoid affecting the model cache.
|
|
||||||
value = value.model_copy(deep=True)
|
|
||||||
# Refresh the url signature before returning it to client.
|
|
||||||
if isinstance(value, FileSegment):
|
|
||||||
file = value.value
|
|
||||||
file.remote_url = file.generate_url()
|
|
||||||
elif isinstance(value, ArrayFileSegment):
|
|
||||||
files = value.value
|
|
||||||
for file in files:
|
|
||||||
file.remote_url = file.generate_url()
|
|
||||||
return _convert_values_to_json_serializable_object(value)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_pagination_parser():
|
def _create_pagination_parser():
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"page",
|
.add_argument(
|
||||||
type=inputs.int_range(1, 100_000),
|
"page",
|
||||||
required=False,
|
type=inputs.int_range(1, 100_000),
|
||||||
default=1,
|
required=False,
|
||||||
location="args",
|
default=1,
|
||||||
help="the page of data requested",
|
location="args",
|
||||||
|
help="the page of data requested",
|
||||||
|
)
|
||||||
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
)
|
)
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -104,13 +78,14 @@ def _api_prerequisite(f):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables")
|
||||||
class RagPipelineVariableCollectionApi(Resource):
|
class RagPipelineVariableCollectionApi(Resource):
|
||||||
@_api_prerequisite
|
@_api_prerequisite
|
||||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||||
@ -168,6 +143,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||||
@_api_prerequisite
|
@_api_prerequisite
|
||||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
@ -190,6 +166,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
|||||||
return Response("", 204)
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||||
class RagPipelineVariableApi(Resource):
|
class RagPipelineVariableApi(Resource):
|
||||||
_PATCH_NAME_FIELD = "name"
|
_PATCH_NAME_FIELD = "name"
|
||||||
_PATCH_VALUE_FIELD = "value"
|
_PATCH_VALUE_FIELD = "value"
|
||||||
@ -231,10 +208,11 @@ class RagPipelineVariableApi(Resource):
|
|||||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||||
# }
|
# }
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
reqparse.RequestParser()
|
||||||
# Parse 'value' field as-is to maintain its original data structure
|
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||||
|
)
|
||||||
|
|
||||||
draft_var_srv = WorkflowDraftVariableService(
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
session=db.session(),
|
session=db.session(),
|
||||||
@ -284,6 +262,7 @@ class RagPipelineVariableApi(Resource):
|
|||||||
return Response("", 204)
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||||
class RagPipelineVariableResetApi(Resource):
|
class RagPipelineVariableResetApi(Resource):
|
||||||
@_api_prerequisite
|
@_api_prerequisite
|
||||||
def put(self, pipeline: Pipeline, variable_id: str):
|
def put(self, pipeline: Pipeline, variable_id: str):
|
||||||
@ -325,6 +304,7 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
|
|||||||
return draft_vars
|
return draft_vars
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables")
|
||||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||||
@_api_prerequisite
|
@_api_prerequisite
|
||||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
@ -332,6 +312,7 @@ class RagPipelineSystemVariableCollectionApi(Resource):
|
|||||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables")
|
||||||
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||||
@_api_prerequisite
|
@_api_prerequisite
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
@ -364,26 +345,3 @@ class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return {"items": env_vars_list}
|
return {"items": env_vars_list}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineVariableCollectionApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineNodeVariableCollectionApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineEnvironmentVariableCollectionApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,11 +1,8 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
from flask_login import current_user # type: ignore
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
@ -13,13 +10,13 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account
|
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
from services.app_dsl_service import ImportStatus
|
from services.app_dsl_service import ImportStatus
|
||||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/imports")
|
||||||
class RagPipelineImportApi(Resource):
|
class RagPipelineImportApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -27,26 +24,29 @@ class RagPipelineImportApi(Resource):
|
|||||||
@marshal_with(pipeline_import_fields)
|
@marshal_with(pipeline_import_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
# Check user role first
|
# Check user role first
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("mode", type=str, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("yaml_content", type=str, location="json")
|
.add_argument("mode", type=str, required=True, location="json")
|
||||||
parser.add_argument("yaml_url", type=str, location="json")
|
.add_argument("yaml_content", type=str, location="json")
|
||||||
parser.add_argument("name", type=str, location="json")
|
.add_argument("yaml_url", type=str, location="json")
|
||||||
parser.add_argument("description", type=str, location="json")
|
.add_argument("name", type=str, location="json")
|
||||||
parser.add_argument("icon_type", type=str, location="json")
|
.add_argument("description", type=str, location="json")
|
||||||
parser.add_argument("icon", type=str, location="json")
|
.add_argument("icon_type", type=str, location="json")
|
||||||
parser.add_argument("icon_background", type=str, location="json")
|
.add_argument("icon", type=str, location="json")
|
||||||
parser.add_argument("pipeline_id", type=str, location="json")
|
.add_argument("icon_background", type=str, location="json")
|
||||||
|
.add_argument("pipeline_id", type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = RagPipelineDslService(session)
|
import_service = RagPipelineDslService(session)
|
||||||
# Import app
|
# Import app
|
||||||
account = cast(Account, current_user)
|
account = current_user
|
||||||
result = import_service.import_rag_pipeline(
|
result = import_service.import_rag_pipeline(
|
||||||
account=account,
|
account=account,
|
||||||
import_mode=args["mode"],
|
import_mode=args["mode"],
|
||||||
@ -59,37 +59,40 @@ class RagPipelineImportApi(Resource):
|
|||||||
|
|
||||||
# Return appropriate status code based on result
|
# Return appropriate status code based on result
|
||||||
status = result.status
|
status = result.status
|
||||||
if status == ImportStatus.FAILED.value:
|
if status == ImportStatus.FAILED:
|
||||||
return result.model_dump(mode="json"), 400
|
return result.model_dump(mode="json"), 400
|
||||||
elif status == ImportStatus.PENDING.value:
|
elif status == ImportStatus.PENDING:
|
||||||
return result.model_dump(mode="json"), 202
|
return result.model_dump(mode="json"), 202
|
||||||
return result.model_dump(mode="json"), 200
|
return result.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||||
class RagPipelineImportConfirmApi(Resource):
|
class RagPipelineImportConfirmApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(pipeline_import_fields)
|
@marshal_with(pipeline_import_fields)
|
||||||
def post(self, import_id):
|
def post(self, import_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
# Check user role first
|
# Check user role first
|
||||||
if not current_user.is_editor:
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# Create service with session
|
# Create service with session
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
import_service = RagPipelineDslService(session)
|
import_service = RagPipelineDslService(session)
|
||||||
# Confirm import
|
# Confirm import
|
||||||
account = cast(Account, current_user)
|
account = current_user
|
||||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
# Return appropriate status code based on result
|
# Return appropriate status code based on result
|
||||||
if result.status == ImportStatus.FAILED.value:
|
if result.status == ImportStatus.FAILED:
|
||||||
return result.model_dump(mode="json"), 400
|
return result.model_dump(mode="json"), 400
|
||||||
return result.model_dump(mode="json"), 200
|
return result.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
|
||||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -97,7 +100,8 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
@ -107,18 +111,19 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
|||||||
return result.model_dump(mode="json"), 200
|
return result.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
|
||||||
class RagPipelineExportApi(Resource):
|
class RagPipelineExportApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
if not current_user.is_editor:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# Add include_secret params
|
# Add include_secret params
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
|
||||||
parser.add_argument("include_secret", type=str, default="false", location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
@ -128,22 +133,3 @@ class RagPipelineExportApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return {"data": result}, 200
|
return {"data": result}, 200
|
||||||
|
|
||||||
|
|
||||||
# Import Rag Pipeline
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineImportApi,
|
|
||||||
"/rag/pipelines/imports",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineImportConfirmApi,
|
|
||||||
"/rag/pipelines/imports/<string:import_id>/confirm",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineImportCheckDependenciesApi,
|
|
||||||
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineExportApi,
|
|
||||||
"/rag/pipelines/<string:pipeline_id>/exports",
|
|
||||||
)
|
|
||||||
|
|||||||
@ -9,8 +9,7 @@ from sqlalchemy.orm import Session
|
|||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from controllers.console import console_ns
|
||||||
from controllers.console import api
|
|
||||||
from controllers.console.app.error import (
|
from controllers.console.app.error import (
|
||||||
ConversationCompletedError,
|
ConversationCompletedError,
|
||||||
DraftWorkflowNotExist,
|
DraftWorkflowNotExist,
|
||||||
@ -19,6 +18,7 @@ from controllers.console.app.error import (
|
|||||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||||
from controllers.console.wraps import (
|
from controllers.console.wraps import (
|
||||||
account_initialization_required,
|
account_initialization_required,
|
||||||
|
edit_permission_required,
|
||||||
setup_required,
|
setup_required,
|
||||||
)
|
)
|
||||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||||
@ -37,8 +37,8 @@ from fields.workflow_run_fields import (
|
|||||||
)
|
)
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, current_user, login_required
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
from services.errors.app import WorkflowHashNotEqualError
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
@ -51,20 +51,18 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
|
||||||
class DraftRagPipelineApi(Resource):
|
class DraftRagPipelineApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
@marshal_with(workflow_fields)
|
@marshal_with(workflow_fields)
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Get draft rag pipeline's workflow
|
Get draft rag pipeline's workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
# fetch draft workflow by app_model
|
# fetch draft workflow by app_model
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||||
@ -79,23 +77,25 @@ class DraftRagPipelineApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def post(self, pipeline: Pipeline):
|
def post(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Sync draft workflow
|
Sync draft workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
content_type = request.headers.get("Content-Type", "")
|
content_type = request.headers.get("Content-Type", "")
|
||||||
|
|
||||||
if "application/json" in content_type:
|
if "application/json" in content_type:
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("hash", type=str, required=False, location="json")
|
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("environment_variables", type=list, required=False, location="json")
|
.add_argument("hash", type=str, required=False, location="json")
|
||||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
.add_argument("environment_variables", type=list, required=False, location="json")
|
||||||
parser.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
|
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||||
|
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
elif "text/plain" in content_type:
|
elif "text/plain" in content_type:
|
||||||
try:
|
try:
|
||||||
@ -148,21 +148,21 @@ class DraftRagPipelineApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def post(self, pipeline: Pipeline, node_id: str):
|
def post(self, pipeline: Pipeline, node_id: str):
|
||||||
"""
|
"""
|
||||||
Run draft workflow iteration node
|
Run draft workflow iteration node
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -182,6 +182,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
|||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -192,11 +193,11 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||||||
Run draft workflow loop node
|
Run draft workflow loop node
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||||
parser.add_argument("inputs", type=dict, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -216,6 +217,7 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
|||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
||||||
class DraftRagPipelineRunApi(Resource):
|
class DraftRagPipelineRunApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -226,14 +228,17 @@ class DraftRagPipelineRunApi(Resource):
|
|||||||
Run draft workflow
|
Run draft workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||||
|
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -250,6 +255,7 @@ class DraftRagPipelineRunApi(Resource):
|
|||||||
raise InvokeRateLimitHttpError(ex.description)
|
raise InvokeRateLimitHttpError(ex.description)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
||||||
class PublishedRagPipelineRunApi(Resource):
|
class PublishedRagPipelineRunApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -260,17 +266,20 @@ class PublishedRagPipelineRunApi(Resource):
|
|||||||
Run published workflow
|
Run published workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("datasource_info_list", type=list, required=True, location="json")
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||||
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||||
parser.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
||||||
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
||||||
|
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args["response_mode"] == "streaming"
|
streaming = args["response_mode"] == "streaming"
|
||||||
@ -299,15 +308,16 @@ class PublishedRagPipelineRunApi(Resource):
|
|||||||
# Run rag pipeline datasource
|
# Run rag pipeline datasource
|
||||||
# """
|
# """
|
||||||
# # The role of the current user in the ta table must be admin, owner, or editor
|
# # The role of the current user in the ta table must be admin, owner, or editor
|
||||||
# if not current_user.is_editor:
|
# if not current_user.has_edit_permission:
|
||||||
# raise Forbidden()
|
# raise Forbidden()
|
||||||
#
|
#
|
||||||
# if not isinstance(current_user, Account):
|
# if not isinstance(current_user, Account):
|
||||||
# raise Forbidden()
|
# raise Forbidden()
|
||||||
#
|
#
|
||||||
# parser = reqparse.RequestParser()
|
# parser = (reqparse.RequestParser()
|
||||||
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
||||||
# parser.add_argument("datasource_type", type=str, required=True, location="json")
|
# .add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
# )
|
||||||
# args = parser.parse_args()
|
# args = parser.parse_args()
|
||||||
#
|
#
|
||||||
# job_id = args.get("job_id")
|
# job_id = args.get("job_id")
|
||||||
@ -340,15 +350,16 @@ class PublishedRagPipelineRunApi(Resource):
|
|||||||
# Run rag pipeline datasource
|
# Run rag pipeline datasource
|
||||||
# """
|
# """
|
||||||
# # The role of the current user in the ta table must be admin, owner, or editor
|
# # The role of the current user in the ta table must be admin, owner, or editor
|
||||||
# if not current_user.is_editor:
|
# if not current_user.has_edit_permission:
|
||||||
# raise Forbidden()
|
# raise Forbidden()
|
||||||
#
|
#
|
||||||
# if not isinstance(current_user, Account):
|
# if not isinstance(current_user, Account):
|
||||||
# raise Forbidden()
|
# raise Forbidden()
|
||||||
#
|
#
|
||||||
# parser = reqparse.RequestParser()
|
# parser = (reqparse.RequestParser()
|
||||||
# parser.add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
|
||||||
# parser.add_argument("datasource_type", type=str, required=True, location="json")
|
# .add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
# )
|
||||||
# args = parser.parse_args()
|
# args = parser.parse_args()
|
||||||
#
|
#
|
||||||
# job_id = args.get("job_id")
|
# job_id = args.get("job_id")
|
||||||
@ -370,6 +381,7 @@ class PublishedRagPipelineRunApi(Resource):
|
|||||||
#
|
#
|
||||||
# return result
|
# return result
|
||||||
#
|
#
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -380,13 +392,16 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||||||
Run rag pipeline datasource
|
Run rag pipeline datasource
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
inputs = args.get("inputs")
|
||||||
@ -412,6 +427,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
||||||
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -422,13 +438,16 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
|||||||
Run rag pipeline datasource
|
Run rag pipeline datasource
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
.add_argument("credential_id", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
inputs = args.get("inputs")
|
||||||
@ -454,6 +473,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||||
class RagPipelineDraftNodeRunApi(Resource):
|
class RagPipelineDraftNodeRunApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -465,11 +485,13 @@ class RagPipelineDraftNodeRunApi(Resource):
|
|||||||
Run draft workflow node
|
Run draft workflow node
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
"inputs", type=dict, required=True, nullable=False, location="json"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
inputs = args.get("inputs")
|
inputs = args.get("inputs")
|
||||||
@ -487,6 +509,7 @@ class RagPipelineDraftNodeRunApi(Resource):
|
|||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||||
class RagPipelineTaskStopApi(Resource):
|
class RagPipelineTaskStopApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -497,7 +520,8 @@ class RagPipelineTaskStopApi(Resource):
|
|||||||
Stop workflow task
|
Stop workflow task
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||||
@ -505,6 +529,7 @@ class RagPipelineTaskStopApi(Resource):
|
|||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/publish")
|
||||||
class PublishedRagPipelineApi(Resource):
|
class PublishedRagPipelineApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -516,7 +541,8 @@ class PublishedRagPipelineApi(Resource):
|
|||||||
Get published pipeline
|
Get published pipeline
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
if not pipeline.is_published:
|
if not pipeline.is_published:
|
||||||
return None
|
return None
|
||||||
@ -536,7 +562,8 @@ class PublishedRagPipelineApi(Resource):
|
|||||||
Publish workflow
|
Publish workflow
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
@ -560,6 +587,7 @@ class PublishedRagPipelineApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs")
|
||||||
class DefaultRagPipelineBlockConfigsApi(Resource):
|
class DefaultRagPipelineBlockConfigsApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -570,7 +598,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
|||||||
Get default block config
|
Get default block config
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
# Get default block configs
|
# Get default block configs
|
||||||
@ -578,6 +607,7 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
|||||||
return rag_pipeline_service.get_default_block_configs()
|
return rag_pipeline_service.get_default_block_configs()
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -588,11 +618,11 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
|||||||
Get default block config
|
Get default block config
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
# The role of the current user in the ta table must be admin, owner, or editor
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||||
parser.add_argument("q", type=str, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
q = args.get("q")
|
q = args.get("q")
|
||||||
@ -609,18 +639,7 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
|||||||
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||||
|
|
||||||
|
|
||||||
class RagPipelineConfigApi(Resource):
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||||
"""Resource for rag pipeline configuration."""
|
|
||||||
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
def get(self, pipeline_id):
|
|
||||||
return {
|
|
||||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class PublishedAllRagPipelineApi(Resource):
|
class PublishedAllRagPipelineApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -631,14 +650,17 @@ class PublishedAllRagPipelineApi(Resource):
|
|||||||
"""
|
"""
|
||||||
Get published workflows
|
Get published workflows
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||||
parser.add_argument("user_id", type=str, required=False, location="args")
|
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
.add_argument("user_id", type=str, required=False, location="args")
|
||||||
|
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
page = int(args.get("page", 1))
|
page = int(args.get("page", 1))
|
||||||
limit = int(args.get("limit", 10))
|
limit = int(args.get("limit", 10))
|
||||||
@ -669,6 +691,7 @@ class PublishedAllRagPipelineApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||||
class RagPipelineByIdApi(Resource):
|
class RagPipelineByIdApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -680,12 +703,15 @@ class RagPipelineByIdApi(Resource):
|
|||||||
Update workflow attributes
|
Update workflow attributes
|
||||||
"""
|
"""
|
||||||
# Check permission
|
# Check permission
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
|
if not current_user.has_edit_permission:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
.add_argument("marked_name", type=str, required=False, location="json")
|
||||||
|
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Validate name and comment length
|
# Validate name and comment length
|
||||||
@ -726,20 +752,18 @@ class RagPipelineByIdApi(Resource):
|
|||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
||||||
class PublishedRagPipelineSecondStepApi(Resource):
|
class PublishedRagPipelineSecondStepApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Get second step parameters of rag pipeline
|
Get second step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
node_id = args.get("node_id")
|
node_id = args.get("node_id")
|
||||||
if not node_id:
|
if not node_id:
|
||||||
@ -751,20 +775,18 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
||||||
class PublishedRagPipelineFirstStepApi(Resource):
|
class PublishedRagPipelineFirstStepApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Get first step parameters of rag pipeline
|
Get first step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
node_id = args.get("node_id")
|
node_id = args.get("node_id")
|
||||||
if not node_id:
|
if not node_id:
|
||||||
@ -776,20 +798,18 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
||||||
class DraftRagPipelineFirstStepApi(Resource):
|
class DraftRagPipelineFirstStepApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Get first step parameters of rag pipeline
|
Get first step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
node_id = args.get("node_id")
|
node_id = args.get("node_id")
|
||||||
if not node_id:
|
if not node_id:
|
||||||
@ -801,20 +821,18 @@ class DraftRagPipelineFirstStepApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
||||||
class DraftRagPipelineSecondStepApi(Resource):
|
class DraftRagPipelineSecondStepApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
def get(self, pipeline: Pipeline):
|
def get(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Get second step parameters of rag pipeline
|
Get second step parameters of rag pipeline
|
||||||
"""
|
"""
|
||||||
# The role of the current user in the ta table must be admin, owner, or editor
|
parser = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
|
||||||
raise Forbidden()
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("node_id", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
node_id = args.get("node_id")
|
node_id = args.get("node_id")
|
||||||
if not node_id:
|
if not node_id:
|
||||||
@ -827,6 +845,7 @@ class DraftRagPipelineSecondStepApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
||||||
class RagPipelineWorkflowRunListApi(Resource):
|
class RagPipelineWorkflowRunListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -837,9 +856,11 @@ class RagPipelineWorkflowRunListApi(Resource):
|
|||||||
"""
|
"""
|
||||||
Get workflow run list
|
Get workflow run list
|
||||||
"""
|
"""
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("last_id", type=uuid_value, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
@ -848,6 +869,7 @@ class RagPipelineWorkflowRunListApi(Resource):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>")
|
||||||
class RagPipelineWorkflowRunDetailApi(Resource):
|
class RagPipelineWorkflowRunDetailApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -866,13 +888,14 @@ class RagPipelineWorkflowRunDetailApi(Resource):
|
|||||||
return workflow_run
|
return workflow_run
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||||
class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
@marshal_with(workflow_run_node_execution_list_fields)
|
@marshal_with(workflow_run_node_execution_list_fields)
|
||||||
def get(self, pipeline: Pipeline, run_id):
|
def get(self, pipeline: Pipeline, run_id: str):
|
||||||
"""
|
"""
|
||||||
Get workflow run node execution list
|
Get workflow run node execution list
|
||||||
"""
|
"""
|
||||||
@ -889,21 +912,17 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
|||||||
return {"data": node_executions}
|
return {"data": node_executions}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/datasource-plugins")
|
||||||
class DatasourceListApi(Resource):
|
class DatasourceListApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
user = current_user
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
if not isinstance(user, Account):
|
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(current_tenant_id))
|
||||||
raise Forbidden()
|
|
||||||
tenant_id = user.current_tenant_id
|
|
||||||
if not tenant_id:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
return jsonable_encoder(RagPipelineManageService.list_rag_pipeline_datasources(tenant_id))
|
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||||
class RagPipelineWorkflowLastRunApi(Resource):
|
class RagPipelineWorkflowLastRunApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -925,13 +944,13 @@ class RagPipelineWorkflowLastRunApi(Resource):
|
|||||||
return node_exec
|
return node_exec
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/transform/datasets/<uuid:dataset_id>")
|
||||||
class RagPipelineTransformApi(Resource):
|
class RagPipelineTransformApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id: str):
|
||||||
if not isinstance(current_user, Account):
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
@ -942,24 +961,26 @@ class RagPipelineTransformApi(Resource):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
||||||
class RagPipelineDatasourceVariableApi(Resource):
|
class RagPipelineDatasourceVariableApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@get_rag_pipeline
|
@get_rag_pipeline
|
||||||
|
@edit_permission_required
|
||||||
@marshal_with(workflow_run_node_execution_fields)
|
@marshal_with(workflow_run_node_execution_fields)
|
||||||
def post(self, pipeline: Pipeline):
|
def post(self, pipeline: Pipeline):
|
||||||
"""
|
"""
|
||||||
Set datasource variables
|
Set datasource variables
|
||||||
"""
|
"""
|
||||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
current_user, _ = current_account_with_tenant()
|
||||||
raise Forbidden()
|
parser = (
|
||||||
|
reqparse.RequestParser()
|
||||||
parser = reqparse.RequestParser()
|
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
.add_argument("datasource_info", type=dict, required=True, location="json")
|
||||||
parser.add_argument("datasource_info", type=dict, required=True, location="json")
|
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||||
parser.add_argument("start_node_id", type=str, required=True, location="json")
|
.add_argument("start_node_title", type=str, required=True, location="json")
|
||||||
parser.add_argument("start_node_title", type=str, required=True, location="json")
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
@ -971,6 +992,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
|||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/rag/pipelines/recommended-plugins")
|
||||||
class RagPipelineRecommendedPluginApi(Resource):
|
class RagPipelineRecommendedPluginApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -979,118 +1001,3 @@ class RagPipelineRecommendedPluginApi(Resource):
|
|||||||
rag_pipeline_service = RagPipelineService()
|
rag_pipeline_service = RagPipelineService()
|
||||||
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
|
recommended_plugins = rag_pipeline_service.get_recommended_plugins()
|
||||||
return recommended_plugins
|
return recommended_plugins
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
DraftRagPipelineApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineConfigApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/config",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
DraftRagPipelineRunApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
PublishedRagPipelineRunApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/run",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineTaskStopApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineDraftNodeRunApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelinePublishedDatasourceNodeRunApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineDraftDatasourceNodeRunApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineDraftRunIterationNodeApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineDraftRunLoopNodeApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
PublishedRagPipelineApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/publish",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
PublishedAllRagPipelineApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
DefaultRagPipelineBlockConfigsApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
DefaultRagPipelineBlockConfigApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineByIdApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineWorkflowRunListApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineWorkflowRunDetailApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineWorkflowRunNodeExecutionListApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflow-runs/<uuid:run_id>/node-executions",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
DatasourceListApi,
|
|
||||||
"/rag/pipelines/datasource-plugins",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
PublishedRagPipelineSecondStepApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
PublishedRagPipelineFirstStepApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
DraftRagPipelineSecondStepApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
DraftRagPipelineFirstStepApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineWorkflowLastRunApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/last-run",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineTransformApi,
|
|
||||||
"/rag/pipelines/transform/datasets/<uuid:dataset_id>",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineDatasourceVariableApi,
|
|
||||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect",
|
|
||||||
)
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
RagPipelineRecommendedPluginApi,
|
|
||||||
"/rag/pipelines/recommended-plugins",
|
|
||||||
)
|
|
||||||
|
|||||||
@ -31,17 +31,19 @@ class WebsiteCrawlApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument(
|
reqparse.RequestParser()
|
||||||
"provider",
|
.add_argument(
|
||||||
type=str,
|
"provider",
|
||||||
choices=["firecrawl", "watercrawl", "jinareader"],
|
type=str,
|
||||||
required=True,
|
choices=["firecrawl", "watercrawl", "jinareader"],
|
||||||
nullable=True,
|
required=True,
|
||||||
location="json",
|
nullable=True,
|
||||||
|
location="json",
|
||||||
|
)
|
||||||
|
.add_argument("url", type=str, required=True, nullable=True, location="json")
|
||||||
|
.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
||||||
)
|
)
|
||||||
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
|
|
||||||
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Create typed request and validate
|
# Create typed request and validate
|
||||||
@ -70,8 +72,7 @@ class WebsiteCrawlStatusApi(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, job_id: str):
|
def get(self, job_id: str):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument(
|
|
||||||
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
|
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@ -3,8 +3,7 @@ from functools import wraps
|
|||||||
|
|
||||||
from controllers.console.datasets.error import PipelineNotFoundError
|
from controllers.console.datasets.error import PipelineNotFoundError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models.account import Account
|
|
||||||
from models.dataset import Pipeline
|
from models.dataset import Pipeline
|
||||||
|
|
||||||
|
|
||||||
@ -17,8 +16,7 @@ def get_rag_pipeline(
|
|||||||
if not kwargs.get("pipeline_id"):
|
if not kwargs.get("pipeline_id"):
|
||||||
raise ValueError("missing pipeline_id in path parameters")
|
raise ValueError("missing pipeline_id in path parameters")
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user is not an account")
|
|
||||||
|
|
||||||
pipeline_id = kwargs.get("pipeline_id")
|
pipeline_id = kwargs.get("pipeline_id")
|
||||||
pipeline_id = str(pipeline_id)
|
pipeline_id = str(pipeline_id)
|
||||||
@ -27,7 +25,7 @@ def get_rag_pipeline(
|
|||||||
|
|
||||||
pipeline = (
|
pipeline = (
|
||||||
db.session.query(Pipeline)
|
db.session.query(Pipeline)
|
||||||
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
|
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -26,9 +26,15 @@ from services.errors.audio import (
|
|||||||
UnsupportedAudioTypeServiceError,
|
UnsupportedAudioTypeServiceError,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .. import console_ns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/audio-to-text",
|
||||||
|
endpoint="installed_app_audio",
|
||||||
|
)
|
||||||
class ChatAudioApi(InstalledAppResource):
|
class ChatAudioApi(InstalledAppResource):
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
@ -65,17 +71,23 @@ class ChatAudioApi(InstalledAppResource):
|
|||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/text-to-audio",
|
||||||
|
endpoint="installed_app_text",
|
||||||
|
)
|
||||||
class ChatTextApi(InstalledAppResource):
|
class ChatTextApi(InstalledAppResource):
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
from flask_restx import reqparse
|
from flask_restx import reqparse
|
||||||
|
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
try:
|
try:
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("voice", type=str, location="json")
|
.add_argument("message_id", type=str, required=False, location="json")
|
||||||
parser.add_argument("text", type=str, location="json")
|
.add_argument("voice", type=str, location="json")
|
||||||
parser.add_argument("streaming", type=bool, location="json")
|
.add_argument("text", type=str, location="json")
|
||||||
|
.add_argument("streaming", type=bool, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
message_id = args.get("message_id", None)
|
message_id = args.get("message_id", None)
|
||||||
|
|||||||
@ -33,22 +33,30 @@ from models.model import AppMode
|
|||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
|
|
||||||
|
from .. import console_ns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# define completion api for user
|
# define completion api for user
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/completion-messages",
|
||||||
|
endpoint="installed_app_completion",
|
||||||
|
)
|
||||||
class CompletionApi(InstalledAppResource):
|
class CompletionApi(InstalledAppResource):
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("query", type=str, location="json", default="")
|
.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("query", type=str, location="json", default="")
|
||||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||||
|
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
streaming = args["response_mode"] == "streaming"
|
streaming = args["response_mode"] == "streaming"
|
||||||
@ -87,6 +95,10 @@ class CompletionApi(InstalledAppResource):
|
|||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
|
||||||
|
endpoint="installed_app_stop_completion",
|
||||||
|
)
|
||||||
class CompletionStopApi(InstalledAppResource):
|
class CompletionStopApi(InstalledAppResource):
|
||||||
def post(self, installed_app, task_id):
|
def post(self, installed_app, task_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
@ -100,6 +112,10 @@ class CompletionStopApi(InstalledAppResource):
|
|||||||
return {"result": "success"}, 200
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/chat-messages",
|
||||||
|
endpoint="installed_app_chat_completion",
|
||||||
|
)
|
||||||
class ChatApi(InstalledAppResource):
|
class ChatApi(InstalledAppResource):
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
@ -107,13 +123,15 @@ class ChatApi(InstalledAppResource):
|
|||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("query", type=str, required=True, location="json")
|
.add_argument("inputs", type=dict, required=True, location="json")
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("query", type=str, required=True, location="json")
|
||||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||||
|
.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args["auto_generate_name"] = False
|
args["auto_generate_name"] = False
|
||||||
@ -153,6 +171,10 @@ class ChatApi(InstalledAppResource):
|
|||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
|
||||||
|
endpoint="installed_app_stop_chat_completion",
|
||||||
|
)
|
||||||
class ChatStopApi(InstalledAppResource):
|
class ChatStopApi(InstalledAppResource):
|
||||||
def post(self, installed_app, task_id):
|
def post(self, installed_app, task_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|||||||
@ -16,7 +16,13 @@ from services.conversation_service import ConversationService
|
|||||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||||
from services.web_conversation_service import WebConversationService
|
from services.web_conversation_service import WebConversationService
|
||||||
|
|
||||||
|
from .. import console_ns
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/conversations",
|
||||||
|
endpoint="installed_app_conversations",
|
||||||
|
)
|
||||||
class ConversationListApi(InstalledAppResource):
|
class ConversationListApi(InstalledAppResource):
|
||||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
@ -25,10 +31,12 @@ class ConversationListApi(InstalledAppResource):
|
|||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("last_id", type=uuid_value, location="args")
|
||||||
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
.add_argument("pinned", type=str, choices=["true", "false", None], location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
pinned = None
|
pinned = None
|
||||||
@ -52,6 +60,10 @@ class ConversationListApi(InstalledAppResource):
|
|||||||
raise NotFound("Last Conversation Not Exists.")
|
raise NotFound("Last Conversation Not Exists.")
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
|
||||||
|
endpoint="installed_app_conversation",
|
||||||
|
)
|
||||||
class ConversationApi(InstalledAppResource):
|
class ConversationApi(InstalledAppResource):
|
||||||
def delete(self, installed_app, c_id):
|
def delete(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
@ -70,6 +82,10 @@ class ConversationApi(InstalledAppResource):
|
|||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
|
||||||
|
endpoint="installed_app_conversation_rename",
|
||||||
|
)
|
||||||
class ConversationRenameApi(InstalledAppResource):
|
class ConversationRenameApi(InstalledAppResource):
|
||||||
@marshal_with(simple_conversation_fields)
|
@marshal_with(simple_conversation_fields)
|
||||||
def post(self, installed_app, c_id):
|
def post(self, installed_app, c_id):
|
||||||
@ -80,9 +96,11 @@ class ConversationRenameApi(InstalledAppResource):
|
|||||||
|
|
||||||
conversation_id = str(c_id)
|
conversation_id = str(c_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("name", type=str, required=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
.add_argument("name", type=str, required=False, location="json")
|
||||||
|
.add_argument("auto_generate", type=bool, required=False, default=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -95,6 +113,10 @@ class ConversationRenameApi(InstalledAppResource):
|
|||||||
raise NotFound("Conversation Not Exists.")
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
|
||||||
|
endpoint="installed_app_conversation_pin",
|
||||||
|
)
|
||||||
class ConversationPinApi(InstalledAppResource):
|
class ConversationPinApi(InstalledAppResource):
|
||||||
def patch(self, installed_app, c_id):
|
def patch(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
@ -114,6 +136,10 @@ class ConversationPinApi(InstalledAppResource):
|
|||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
|
||||||
|
endpoint="installed_app_conversation_unpin",
|
||||||
|
)
|
||||||
class ConversationUnPinApi(InstalledAppResource):
|
class ConversationUnPinApi(InstalledAppResource):
|
||||||
def patch(self, installed_app, c_id):
|
def patch(self, installed_app, c_id):
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|||||||
@ -6,31 +6,29 @@ from flask_restx import Resource, inputs, marshal_with, reqparse
|
|||||||
from sqlalchemy import and_, select
|
from sqlalchemy import and_, select
|
||||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.installed_app_fields import installed_app_list_fields
|
from fields.installed_app_fields import installed_app_list_fields
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account, App, InstalledApp, RecommendedApp
|
from models import App, InstalledApp, RecommendedApp
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
from services.app_service import AppService
|
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/installed-apps")
|
||||||
class InstalledAppsListApi(Resource):
|
class InstalledAppsListApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(installed_app_list_fields)
|
@marshal_with(installed_app_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
app_id = request.args.get("app_id", default=None, type=str)
|
app_id = request.args.get("app_id", default=None, type=str)
|
||||||
if not isinstance(current_user, Account):
|
current_user, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
current_tenant_id = current_user.current_tenant_id
|
|
||||||
|
|
||||||
if app_id:
|
if app_id:
|
||||||
installed_apps = db.session.scalars(
|
installed_apps = db.session.scalars(
|
||||||
@ -68,31 +66,26 @@ class InstalledAppsListApi(Resource):
|
|||||||
|
|
||||||
# Pre-filter out apps without setting or with sso_verified
|
# Pre-filter out apps without setting or with sso_verified
|
||||||
filtered_installed_apps = []
|
filtered_installed_apps = []
|
||||||
app_id_to_app_code = {}
|
|
||||||
|
|
||||||
for installed_app in installed_app_list:
|
for installed_app in installed_app_list:
|
||||||
app_id = installed_app["app"].id
|
app_id = installed_app["app"].id
|
||||||
webapp_setting = webapp_settings.get(app_id)
|
webapp_setting = webapp_settings.get(app_id)
|
||||||
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
|
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
|
||||||
continue
|
continue
|
||||||
app_code = AppService.get_app_code_by_id(str(app_id))
|
|
||||||
app_id_to_app_code[app_id] = app_code
|
|
||||||
filtered_installed_apps.append(installed_app)
|
filtered_installed_apps.append(installed_app)
|
||||||
|
|
||||||
app_codes = list(app_id_to_app_code.values())
|
|
||||||
|
|
||||||
# Batch permission check
|
# Batch permission check
|
||||||
|
app_ids = [installed_app["app"].id for installed_app in filtered_installed_apps]
|
||||||
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
|
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
app_codes=app_codes,
|
app_ids=app_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keep only allowed apps
|
# Keep only allowed apps
|
||||||
res = []
|
res = []
|
||||||
for installed_app in filtered_installed_apps:
|
for installed_app in filtered_installed_apps:
|
||||||
app_id = installed_app["app"].id
|
app_id = installed_app["app"].id
|
||||||
app_code = app_id_to_app_code[app_id]
|
if permissions.get(app_id):
|
||||||
if permissions.get(app_code):
|
|
||||||
res.append(installed_app)
|
res.append(installed_app)
|
||||||
|
|
||||||
installed_app_list = res
|
installed_app_list = res
|
||||||
@ -112,17 +105,15 @@ class InstalledAppsListApi(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@cloud_edition_billing_resource_check("apps")
|
@cloud_edition_billing_resource_check("apps")
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("app_id", type=str, required=True, help="Invalid app_id")
|
||||||
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
|
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
|
||||||
if recommended_app is None:
|
if recommended_app is None:
|
||||||
raise NotFound("App not found")
|
raise NotFound("App not found")
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
current_tenant_id = current_user.current_tenant_id
|
|
||||||
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
app = db.session.query(App).where(App.id == args["app_id"]).first()
|
||||||
|
|
||||||
if app is None:
|
if app is None:
|
||||||
@ -154,6 +145,7 @@ class InstalledAppsListApi(Resource):
|
|||||||
return {"message": "App installed successfully"}
|
return {"message": "App installed successfully"}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/installed-apps/<uuid:installed_app_id>")
|
||||||
class InstalledAppApi(InstalledAppResource):
|
class InstalledAppApi(InstalledAppResource):
|
||||||
"""
|
"""
|
||||||
update and delete an installed app
|
update and delete an installed app
|
||||||
@ -161,9 +153,8 @@ class InstalledAppApi(InstalledAppResource):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def delete(self, installed_app):
|
def delete(self, installed_app):
|
||||||
if not isinstance(current_user, Account):
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
raise ValueError("current_user must be an Account instance")
|
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||||
if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
|
|
||||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||||
|
|
||||||
db.session.delete(installed_app)
|
db.session.delete(installed_app)
|
||||||
@ -172,8 +163,7 @@ class InstalledAppApi(InstalledAppResource):
|
|||||||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||||
|
|
||||||
def patch(self, installed_app):
|
def patch(self, installed_app):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("is_pinned", type=inputs.boolean)
|
||||||
parser.add_argument("is_pinned", type=inputs.boolean)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
commit_args = False
|
commit_args = False
|
||||||
@ -185,7 +175,3 @@ class InstalledAppApi(InstalledAppResource):
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {"result": "success", "message": "App info updated successfully"}
|
return {"result": "success", "message": "App info updated successfully"}
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(InstalledAppsListApi, "/installed-apps")
|
|
||||||
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>")
|
|
||||||
|
|||||||
@ -23,8 +23,7 @@ from core.model_runtime.errors.invoke import InvokeError
|
|||||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models import Account
|
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.app import MoreLikeThisDisabledError
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
@ -36,27 +35,34 @@ from services.errors.message import (
|
|||||||
)
|
)
|
||||||
from services.message_service import MessageService
|
from services.message_service import MessageService
|
||||||
|
|
||||||
|
from .. import console_ns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/messages",
|
||||||
|
endpoint="installed_app_messages",
|
||||||
|
)
|
||||||
class MessageListApi(InstalledAppResource):
|
class MessageListApi(InstalledAppResource):
|
||||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
raise NotChatAppError()
|
raise NotChatAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("first_id", type=uuid_value, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
return MessageService.pagination_by_first_id(
|
return MessageService.pagination_by_first_id(
|
||||||
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||||
)
|
)
|
||||||
@ -66,20 +72,25 @@ class MessageListApi(InstalledAppResource):
|
|||||||
raise NotFound("First Message Not Exists.")
|
raise NotFound("First Message Not Exists.")
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
|
||||||
|
endpoint="installed_app_message_feedback",
|
||||||
|
)
|
||||||
class MessageFeedbackApi(InstalledAppResource):
|
class MessageFeedbackApi(InstalledAppResource):
|
||||||
def post(self, installed_app, message_id):
|
def post(self, installed_app, message_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("content", type=str, location="json")
|
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||||
|
.add_argument("content", type=str, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
MessageService.create_feedback(
|
MessageService.create_feedback(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
@ -93,16 +104,20 @@ class MessageFeedbackApi(InstalledAppResource):
|
|||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
|
||||||
|
endpoint="installed_app_more_like_this",
|
||||||
|
)
|
||||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||||
def get(self, installed_app, message_id):
|
def get(self, installed_app, message_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument(
|
||||||
parser.add_argument(
|
|
||||||
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -110,8 +125,6 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
|||||||
streaming = args["response_mode"] == "streaming"
|
streaming = args["response_mode"] == "streaming"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
response = AppGenerateService.generate_more_like_this(
|
response = AppGenerateService.generate_more_like_this(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
user=current_user,
|
user=current_user,
|
||||||
@ -139,8 +152,13 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
|||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||||
|
endpoint="installed_app_suggested_question",
|
||||||
|
)
|
||||||
class MessageSuggestedQuestionApi(InstalledAppResource):
|
class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||||
def get(self, installed_app, message_id):
|
def get(self, installed_app, message_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||||
@ -149,8 +167,6 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
|||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
questions = MessageService.get_suggested_questions_after_answer(
|
questions = MessageService.get_suggested_questions_after_answer(
|
||||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from flask_restx import marshal_with
|
from flask_restx import marshal_with
|
||||||
|
|
||||||
from controllers.common import fields
|
from controllers.common import fields
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.error import AppUnavailableError
|
from controllers.console.app.error import AppUnavailableError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||||
@ -9,6 +9,7 @@ from models.model import AppMode, InstalledApp
|
|||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters")
|
||||||
class AppParameterApi(InstalledAppResource):
|
class AppParameterApi(InstalledAppResource):
|
||||||
"""Resource for app variables."""
|
"""Resource for app variables."""
|
||||||
|
|
||||||
@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
|
|||||||
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
||||||
class ExploreAppMetaApi(InstalledAppResource):
|
class ExploreAppMetaApi(InstalledAppResource):
|
||||||
def get(self, installed_app: InstalledApp):
|
def get(self, installed_app: InstalledApp):
|
||||||
"""Get app meta"""
|
"""Get app meta"""
|
||||||
@ -46,9 +48,3 @@ class ExploreAppMetaApi(InstalledAppResource):
|
|||||||
if not app_model:
|
if not app_model:
|
||||||
raise ValueError("App not found")
|
raise ValueError("App not found")
|
||||||
return AppService().get_app_meta(app_model)
|
return AppService().get_app_meta(app_model)
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters"
|
|
||||||
)
|
|
||||||
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
|
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from libs.helper import AppIconUrlField
|
from libs.helper import AppIconUrlField
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_user, login_required
|
||||||
@ -35,14 +35,14 @@ recommended_app_list_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/explore/apps")
|
||||||
class RecommendedAppListApi(Resource):
|
class RecommendedAppListApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(recommended_app_list_fields)
|
@marshal_with(recommended_app_list_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
# language args
|
# language args
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("language", type=str, location="args")
|
||||||
parser.add_argument("language", type=str, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
language = args.get("language")
|
language = args.get("language")
|
||||||
@ -56,13 +56,10 @@ class RecommendedAppListApi(Resource):
|
|||||||
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
|
return RecommendedAppService.get_recommended_apps_and_categories(language_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/explore/apps/<uuid:app_id>")
|
||||||
class RecommendedAppApi(Resource):
|
class RecommendedAppApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self, app_id):
|
def get(self, app_id):
|
||||||
app_id = str(app_id)
|
app_id = str(app_id)
|
||||||
return RecommendedAppService.get_recommend_app_detail(app_id)
|
return RecommendedAppService.get_recommend_app_detail(app_id)
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(RecommendedAppListApi, "/explore/apps")
|
|
||||||
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>")
|
|
||||||
|
|||||||
@ -2,13 +2,12 @@ from flask_restx import fields, marshal_with, reqparse
|
|||||||
from flask_restx.inputs import int_range
|
from flask_restx.inputs import int_range
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console import api
|
from controllers.console import console_ns
|
||||||
from controllers.console.explore.error import NotCompletionAppError
|
from controllers.console.explore.error import NotCompletionAppError
|
||||||
from controllers.console.explore.wraps import InstalledAppResource
|
from controllers.console.explore.wraps import InstalledAppResource
|
||||||
from fields.conversation_fields import message_file_fields
|
from fields.conversation_fields import message_file_fields
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models import Account
|
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
from services.saved_message_service import SavedMessageService
|
from services.saved_message_service import SavedMessageService
|
||||||
|
|
||||||
@ -25,6 +24,7 @@ message_fields = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
|
||||||
class SavedMessageListApi(InstalledAppResource):
|
class SavedMessageListApi(InstalledAppResource):
|
||||||
saved_message_infinite_scroll_pagination_fields = {
|
saved_message_infinite_scroll_pagination_fields = {
|
||||||
"limit": fields.Integer,
|
"limit": fields.Integer,
|
||||||
@ -34,31 +34,30 @@ class SavedMessageListApi(InstalledAppResource):
|
|||||||
|
|
||||||
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
@marshal_with(saved_message_infinite_scroll_pagination_fields)
|
||||||
def get(self, installed_app):
|
def get(self, installed_app):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
.add_argument("last_id", type=uuid_value, location="args")
|
||||||
|
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"])
|
||||||
|
|
||||||
def post(self, installed_app):
|
def post(self, installed_app):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("message_id", type=uuid_value, required=True, location="json")
|
||||||
parser.add_argument("message_id", type=uuid_value, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
SavedMessageService.save(app_model, current_user, args["message_id"])
|
SavedMessageService.save(app_model, current_user, args["message_id"])
|
||||||
except MessageNotExistsError:
|
except MessageNotExistsError:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
@ -66,8 +65,12 @@ class SavedMessageListApi(InstalledAppResource):
|
|||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route(
|
||||||
|
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", endpoint="installed_app_saved_message"
|
||||||
|
)
|
||||||
class SavedMessageApi(InstalledAppResource):
|
class SavedMessageApi(InstalledAppResource):
|
||||||
def delete(self, installed_app, message_id):
|
def delete(self, installed_app, message_id):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
|
|
||||||
message_id = str(message_id)
|
message_id = str(message_id)
|
||||||
@ -75,20 +78,6 @@ class SavedMessageApi(InstalledAppResource):
|
|||||||
if app_model.mode != "completion":
|
if app_model.mode != "completion":
|
||||||
raise NotCompletionAppError()
|
raise NotCompletionAppError()
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("current_user must be an Account instance")
|
|
||||||
SavedMessageService.delete(app_model, current_user, message_id)
|
SavedMessageService.delete(app_model, current_user, message_id)
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(
|
|
||||||
SavedMessageListApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/saved-messages",
|
|
||||||
endpoint="installed_app_saved_messages",
|
|
||||||
)
|
|
||||||
api.add_resource(
|
|
||||||
SavedMessageApi,
|
|
||||||
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
|
|
||||||
endpoint="installed_app_saved_message",
|
|
||||||
)
|
|
||||||
|
|||||||
@ -22,19 +22,23 @@ from core.errors.error import (
|
|||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.login import current_user
|
from libs.login import current_account_with_tenant
|
||||||
from models.model import AppMode, InstalledApp
|
from models.model import AppMode, InstalledApp
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.llm import InvokeRateLimitError
|
from services.errors.llm import InvokeRateLimitError
|
||||||
|
|
||||||
|
from .. import console_ns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||||
def post(self, installed_app: InstalledApp):
|
def post(self, installed_app: InstalledApp):
|
||||||
"""
|
"""
|
||||||
Run workflow
|
Run workflow
|
||||||
"""
|
"""
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
app_model = installed_app.app
|
app_model = installed_app.app
|
||||||
if not app_model:
|
if not app_model:
|
||||||
raise NotWorkflowAppError()
|
raise NotWorkflowAppError()
|
||||||
@ -42,11 +46,12 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
|||||||
if app_mode != AppMode.WORKFLOW:
|
if app_mode != AppMode.WORKFLOW:
|
||||||
raise NotWorkflowAppError()
|
raise NotWorkflowAppError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
|
.add_argument("files", type=list, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
assert current_user is not None
|
|
||||||
try:
|
try:
|
||||||
response = AppGenerateService.generate(
|
response = AppGenerateService.generate(
|
||||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||||
@ -70,6 +75,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
|||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop")
|
||||||
class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||||
def post(self, installed_app: InstalledApp, task_id: str):
|
def post(self, installed_app: InstalledApp, task_id: str):
|
||||||
"""
|
"""
|
||||||
@ -81,7 +87,6 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
|||||||
app_mode = AppMode.value_of(app_model.mode)
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
if app_mode != AppMode.WORKFLOW:
|
if app_mode != AppMode.WORKFLOW:
|
||||||
raise NotWorkflowAppError()
|
raise NotWorkflowAppError()
|
||||||
assert current_user is not None
|
|
||||||
|
|
||||||
# Stop using both mechanisms for backward compatibility
|
# Stop using both mechanisms for backward compatibility
|
||||||
# Legacy stop flag mechanism (without user check)
|
# Legacy stop flag mechanism (without user check)
|
||||||
|
|||||||
@ -2,16 +2,14 @@ from collections.abc import Callable
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Concatenate, ParamSpec, TypeVar
|
from typing import Concatenate, ParamSpec, TypeVar
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from controllers.console.explore.error import AppAccessDeniedError
|
from controllers.console.explore.error import AppAccessDeniedError
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import InstalledApp
|
from models import InstalledApp
|
||||||
from services.app_service import AppService
|
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
@ -24,11 +22,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
|||||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||||
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
installed_app = (
|
installed_app = (
|
||||||
db.session.query(InstalledApp)
|
db.session.query(InstalledApp)
|
||||||
.where(
|
.where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id)
|
||||||
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
|
|
||||||
)
|
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -54,13 +51,13 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
|||||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
feature = FeatureService.get_system_features()
|
feature = FeatureService.get_system_features()
|
||||||
if feature.webapp_auth.enabled:
|
if feature.webapp_auth.enabled:
|
||||||
app_id = installed_app.app_id
|
app_id = installed_app.app_id
|
||||||
app_code = AppService.get_app_code_by_id(app_id)
|
|
||||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||||
user_id=str(current_user.id),
|
user_id=str(current_user.id),
|
||||||
app_code=app_code,
|
app_id=app_id,
|
||||||
)
|
)
|
||||||
if not res:
|
if not res:
|
||||||
raise AppAccessDeniedError()
|
raise AppAccessDeniedError()
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from controllers.console import api, console_ns
|
from controllers.console import api, console_ns
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from fields.api_based_extension_fields import api_based_extension_fields
|
from fields.api_based_extension_fields import api_based_extension_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models.api_based_extension import APIBasedExtension
|
from models.api_based_extension import APIBasedExtension
|
||||||
from services.api_based_extension_service import APIBasedExtensionService
|
from services.api_based_extension_service import APIBasedExtensionService
|
||||||
from services.code_based_extension_service import CodeBasedExtensionService
|
from services.code_based_extension_service import CodeBasedExtensionService
|
||||||
@ -30,8 +29,7 @@ class CodeBasedExtensionAPI(Resource):
|
|||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("module", type=str, required=True, location="args")
|
||||||
parser.add_argument("module", type=str, required=True, location="args")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])}
|
||||||
@ -47,7 +45,7 @@ class APIBasedExtensionAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def get(self):
|
def get(self):
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||||
|
|
||||||
@api.doc("create_api_based_extension")
|
@api.doc("create_api_based_extension")
|
||||||
@ -68,14 +66,11 @@ class APIBasedExtensionAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
args = api.payload
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
|
||||||
parser.add_argument("api_key", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
extension_data = APIBasedExtension(
|
extension_data = APIBasedExtension(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_tenant_id,
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
api_endpoint=args["api_endpoint"],
|
api_endpoint=args["api_endpoint"],
|
||||||
api_key=args["api_key"],
|
api_key=args["api_key"],
|
||||||
@ -96,7 +91,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def get(self, id):
|
def get(self, id):
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
_, tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
@ -120,15 +115,11 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
@marshal_with(api_based_extension_fields)
|
@marshal_with(api_based_extension_fields)
|
||||||
def post(self, id):
|
def post(self, id):
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
args = api.payload
|
||||||
parser.add_argument("name", type=str, required=True, location="json")
|
|
||||||
parser.add_argument("api_endpoint", type=str, required=True, location="json")
|
|
||||||
parser.add_argument("api_key", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
extension_data_from_db.name = args["name"]
|
extension_data_from_db.name = args["name"]
|
||||||
extension_data_from_db.api_endpoint = args["api_endpoint"]
|
extension_data_from_db.api_endpoint = args["api_endpoint"]
|
||||||
@ -147,9 +138,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
|||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def delete(self, id):
|
def delete(self, id):
|
||||||
api_based_extension_id = str(id)
|
api_based_extension_id = str(id)
|
||||||
tenant_id = current_user.current_tenant_id
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||||
|
|
||||||
APIBasedExtensionService.delete(extension_data_from_db)
|
APIBasedExtensionService.delete(extension_data_from_db)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, fields
|
from flask_restx import Resource, fields
|
||||||
|
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
from . import api, console_ns
|
from . import api, console_ns
|
||||||
@ -23,7 +22,9 @@ class FeatureApi(Resource):
|
|||||||
@cloud_utm_record
|
@cloud_utm_record
|
||||||
def get(self):
|
def get(self):
|
||||||
"""Get feature configuration for current tenant"""
|
"""Get feature configuration for current tenant"""
|
||||||
return FeatureService.get_features(current_user.current_tenant_id).model_dump()
|
_, current_tenant_id = current_account_with_tenant()
|
||||||
|
|
||||||
|
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/system-features")
|
@console_ns.route("/system-features")
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with
|
from flask_restx import Resource, marshal_with
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
@ -9,6 +8,7 @@ import services
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants import DOCUMENT_EXTENSIONS
|
from constants import DOCUMENT_EXTENSIONS
|
||||||
from controllers.common.errors import (
|
from controllers.common.errors import (
|
||||||
|
BlockedFileExtensionError,
|
||||||
FilenameNotExistsError,
|
FilenameNotExistsError,
|
||||||
FileTooLargeError,
|
FileTooLargeError,
|
||||||
NoFileUploadedError,
|
NoFileUploadedError,
|
||||||
@ -22,13 +22,15 @@ from controllers.console.wraps import (
|
|||||||
)
|
)
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.file_fields import file_fields, upload_config_fields
|
from fields.file_fields import file_fields, upload_config_fields
|
||||||
from libs.login import login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
from models import Account
|
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
from . import console_ns
|
||||||
|
|
||||||
PREVIEW_WORDS_LIMIT = 3000
|
PREVIEW_WORDS_LIMIT = 3000
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/files/upload")
|
||||||
class FileApi(Resource):
|
class FileApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -38,6 +40,7 @@ class FileApi(Resource):
|
|||||||
return {
|
return {
|
||||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||||
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
|
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
|
||||||
|
"file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
|
||||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||||
@ -50,6 +53,7 @@ class FileApi(Resource):
|
|||||||
@marshal_with(file_fields)
|
@marshal_with(file_fields)
|
||||||
@cloud_edition_billing_resource_check("documents")
|
@cloud_edition_billing_resource_check("documents")
|
||||||
def post(self):
|
def post(self):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
source_str = request.form.get("source")
|
source_str = request.form.get("source")
|
||||||
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
|
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
|
||||||
|
|
||||||
@ -62,16 +66,12 @@ class FileApi(Resource):
|
|||||||
|
|
||||||
if not file.filename:
|
if not file.filename:
|
||||||
raise FilenameNotExistsError
|
raise FilenameNotExistsError
|
||||||
|
|
||||||
if source == "datasets" and not current_user.is_dataset_editor:
|
if source == "datasets" and not current_user.is_dataset_editor:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
|
|
||||||
if source not in ("datasets", None):
|
if source not in ("datasets", None):
|
||||||
source = None
|
source = None
|
||||||
|
|
||||||
if not isinstance(current_user, Account):
|
|
||||||
raise ValueError("Invalid user account")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
upload_file = FileService(db.engine).upload_file(
|
upload_file = FileService(db.engine).upload_file(
|
||||||
filename=file.filename,
|
filename=file.filename,
|
||||||
@ -84,10 +84,13 @@ class FileApi(Resource):
|
|||||||
raise FileTooLargeError(file_too_large_error.description)
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
|
||||||
|
raise BlockedFileExtensionError(blocked_extension_error.description)
|
||||||
|
|
||||||
return upload_file, 201
|
return upload_file, 201
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/files/<uuid:file_id>/preview")
|
||||||
class FilePreviewApi(Resource):
|
class FilePreviewApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -98,9 +101,10 @@ class FilePreviewApi(Resource):
|
|||||||
return {"content": text}
|
return {"content": text}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/files/support-type")
|
||||||
class FileSupportTypeApi(Resource):
|
class FileSupportTypeApi(Resource):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
|
return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}
|
||||||
|
|||||||
@ -57,8 +57,7 @@ class InitValidateAPI(Resource):
|
|||||||
if tenant_count > 0:
|
if tenant_count > 0:
|
||||||
raise AlreadySetupError()
|
raise AlreadySetupError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("password", type=StrLen(30), required=True, location="json")
|
||||||
parser.add_argument("password", type=StrLen(30), required=True, location="json")
|
|
||||||
input_password = parser.parse_args()["password"]
|
input_password = parser.parse_args()["password"]
|
||||||
|
|
||||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restx import Resource, marshal_with, reqparse
|
from flask_restx import Resource, marshal_with, reqparse
|
||||||
|
|
||||||
import services
|
import services
|
||||||
@ -16,10 +14,13 @@ from core.file import helpers as file_helpers
|
|||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||||
from models.account import Account
|
from libs.login import current_account_with_tenant
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
from . import console_ns
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/remote-files/<path:url>")
|
||||||
class RemoteFileInfoApi(Resource):
|
class RemoteFileInfoApi(Resource):
|
||||||
@marshal_with(remote_file_info_fields)
|
@marshal_with(remote_file_info_fields)
|
||||||
def get(self, url):
|
def get(self, url):
|
||||||
@ -35,11 +36,11 @@ class RemoteFileInfoApi(Resource):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/remote-files/upload")
|
||||||
class RemoteFileUploadApi(Resource):
|
class RemoteFileUploadApi(Resource):
|
||||||
@marshal_with(file_fields_with_signed_url)
|
@marshal_with(file_fields_with_signed_url)
|
||||||
def post(self):
|
def post(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser().add_argument("url", type=str, required=True, help="URL is required")
|
||||||
parser.add_argument("url", type=str, required=True, help="URL is required")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
url = args["url"]
|
url = args["url"]
|
||||||
@ -61,7 +62,7 @@ class RemoteFileUploadApi(Resource):
|
|||||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user = cast(Account, current_user)
|
user, _ = current_account_with_tenant()
|
||||||
upload_file = FileService(db.engine).upload_file(
|
upload_file = FileService(db.engine).upload_file(
|
||||||
filename=file_info.filename,
|
filename=file_info.filename,
|
||||||
content=content,
|
content=content,
|
||||||
|
|||||||
@ -69,15 +69,22 @@ class SetupApi(Resource):
|
|||||||
if not get_init_validate_status():
|
if not get_init_validate_status():
|
||||||
raise NotInitValidateError()
|
raise NotInitValidateError()
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
parser = (
|
||||||
parser.add_argument("email", type=email, required=True, location="json")
|
reqparse.RequestParser()
|
||||||
parser.add_argument("name", type=StrLen(30), required=True, location="json")
|
.add_argument("email", type=email, required=True, location="json")
|
||||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
.add_argument("name", type=StrLen(30), required=True, location="json")
|
||||||
|
.add_argument("password", type=valid_password, required=True, location="json")
|
||||||
|
.add_argument("language", type=str, required=False, location="json")
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# setup
|
# setup
|
||||||
RegisterService.setup(
|
RegisterService.setup(
|
||||||
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
|
email=args["email"],
|
||||||
|
name=args["name"],
|
||||||
|
password=args["password"],
|
||||||
|
ip_address=extract_remote_ip(request),
|
||||||
|
language=args["language"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"result": "success"}, 201
|
return {"result": "success"}, 201
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user