mirror of
https://github.com/langgenius/dify.git
synced 2026-01-19 19:55:06 +08:00
Compare commits
595 Commits
purity
...
copilot/re
| Author | SHA1 | Date | |
|---|---|---|---|
| f5528f2030 | |||
| 6efdc94661 | |||
| 68526c09fc | |||
| a78bc507c0 | |||
| e83c7438cb | |||
| 82068a6918 | |||
| 108bcbeb7c | |||
| c4b02be6d3 | |||
| 30eebf804f | |||
| ad7fdd18d0 | |||
| 5d2fbf5215 | |||
| 4a89403566 | |||
| e0c05b2123 | |||
| 85b99580ea | |||
| 15fbedfcad | |||
| 1e6d0de48b | |||
| cad751c00c | |||
| a47276ac24 | |||
| 20403c69b2 | |||
| ffc04f2a9b | |||
| d1580791e4 | |||
| c74eb4fcf3 | |||
| a798534337 | |||
| 470883858e | |||
| 4f4911686d | |||
| 6d479dcdbb | |||
| 24348c40a6 | |||
| a39b50adbb | |||
| 81832c14ee | |||
| b86022c64a | |||
| 45e816a9f6 | |||
| 667b1c37a3 | |||
| b75d533f9b | |||
| aece55d82f | |||
| c432b398f4 | |||
| 9cb2645793 | |||
| 6ac61bd585 | |||
| b02165ffe6 | |||
| 6c576e2c66 | |||
| b0e7e7752f | |||
| 2799b79e8c | |||
| 805a1479f9 | |||
| fe6538b08d | |||
| 1bbb9d6644 | |||
| 5c06e285ec | |||
| 19c92fd670 | |||
| 6026bd873b | |||
| 1369119a0c | |||
| b76e17b25d | |||
| ca7794305b | |||
| fd255e81e1 | |||
| 09d31d1263 | |||
| 47dc26f011 | |||
| 123bb3ec08 | |||
| 90f77282e3 | |||
| 5208867ccc | |||
| edc7ccc795 | |||
| c9798f6425 | |||
| 20ecf7f1d0 | |||
| 9dcb780fcb | |||
| 1cb7b09933 | |||
| 2c62a77cf4 | |||
| b9bc48d8dd | |||
| ed234e311b | |||
| 9843fec393 | |||
| aa4cabdeb5 | |||
| eea713b668 | |||
| fc62538a94 | |||
| 7994144df7 | |||
| e153c483b6 | |||
| 422bb4d4bb | |||
| 87a80d7613 | |||
| e91105ca87 | |||
| 37903722fe | |||
| f4c82d0010 | |||
| fe50093c18 | |||
| 4317af1e90 | |||
| 61a0fcc2ea | |||
| f627348b11 | |||
| 87fb9a6b69 | |||
| 97a2e2ec2e | |||
| 68d357d7f6 | |||
| a103ad3ee7 | |||
| f65d5a9761 | |||
| 6e0a5f5bbd | |||
| 22f858152f | |||
| 775d2e14fc | |||
| 744b287e67 | |||
| c0fc5d98f0 | |||
| 08ea79d730 | |||
| f31b821cc0 | |||
| 34be16874f | |||
| e9738b891f | |||
| 829796514a | |||
| ef1db35f80 | |||
| f9c67621ca | |||
| e29e8e3180 | |||
| 7a81e720d4 | |||
| 55600c0eb1 | |||
| 35e41d7d68 | |||
| b610cf9a11 | |||
| c8e9edc024 | |||
| 471cd760d7 | |||
| 7f48c57edf | |||
| 6569801162 | |||
| 9dd83f50a7 | |||
| 59c56b1b0d | |||
| 94cd2de940 | |||
| 3c23375607 | |||
| 56047f638f | |||
| 9c01d3e775 | |||
| c85c87f3da | |||
| eaa02e3d55 | |||
| 0219222a60 | |||
| dba659b220 | |||
| ee6458768e | |||
| ed3d02dc6d | |||
| 95471b1188 | |||
| 6190cfbfd8 | |||
| 11f2f95103 | |||
| 2abbc14703 | |||
| b2b2816ade | |||
| 4461df1bd9 | |||
| f7f6b4a8b0 | |||
| 41be581594 | |||
| 20ad5b7ac2 | |||
| a1c0bd7a1c | |||
| fd7c4e8a6d | |||
| 41e549af14 | |||
| b7360140ee | |||
| c71f7c7613 | |||
| c905c47775 | |||
| 4ca7ba000c | |||
| f260627660 | |||
| 1e9142c213 | |||
| 82890fe38e | |||
| 7dc7c8af98 | |||
| addebc465a | |||
| 5ab315aeaf | |||
| f092bc1912 | |||
| 23b49b8304 | |||
| 9e97248ede | |||
| d532b06310 | |||
| 07a2281730 | |||
| 42385f3ffa | |||
| c597234374 | |||
| 3de73f07c6 | |||
| 0caeaf6e5c | |||
| 3395297c3e | |||
| e60a7c7143 | |||
| 0e62a66cc2 | |||
| ff32dff163 | |||
| 543c5236e7 | |||
| 341b3ae7c9 | |||
| f01907aac2 | |||
| a7c855cab8 | |||
| 29afc0657d | |||
| d9860b8907 | |||
| dc1ae57dc6 | |||
| d6bd2a9bdb | |||
| c9eed67cf6 | |||
| 0ded6303c1 | |||
| b6e0abadab | |||
| 43bcf40f80 | |||
| f06025a342 | |||
| 24fb95b050 | |||
| 49fca63927 | |||
| ce5fe86430 | |||
| 666586b59c | |||
| 8a2851551a | |||
| a2fe4a28c3 | |||
| 417ebd160b | |||
| 82be305680 | |||
| 03002f4971 | |||
| 1e7e8a8988 | |||
| a715d5ac23 | |||
| 398c8117fe | |||
| f45c18ee35 | |||
| 15c1db42dd | |||
| a31c01f8d9 | |||
| 62753cdf13 | |||
| dc7ce125ad | |||
| eabdb09f8e | |||
| fa6d03c979 | |||
| 634fb192ef | |||
| a4b38e7521 | |||
| 8ff6de91b0 | |||
| 7fa0ad3161 | |||
| 53b21eea61 | |||
| 2f3a61b51b | |||
| 8bca7814f4 | |||
| 92c81b1833 | |||
| 44553d412c | |||
| 95ce224df0 | |||
| 8555635967 | |||
| e843fe8aa6 | |||
| b198c9474a | |||
| 4bb00b83d9 | |||
| c91cbf6b97 | |||
| f6ede6f1c1 | |||
| 65976b27fe | |||
| 2d73ee64a3 | |||
| c61c2b0abd | |||
| 40d3332690 | |||
| 8e45753c68 | |||
| 73e217ab0d | |||
| 26ff59172e | |||
| bebb4ffbaa | |||
| 523da66134 | |||
| e1ca7a9bdb | |||
| 9a8cf709ba | |||
| f909040567 | |||
| 845adb664a | |||
| 0c6cae2d59 | |||
| a893ee0ffc | |||
| 82b63cc6e2 | |||
| c327cfa86e | |||
| 82219c1162 | |||
| cfc3f1527a | |||
| caf1a5fbab | |||
| 4a6398fc1f | |||
| 2bcf96565a | |||
| 9a9d6a4a2b | |||
| 05f66fcf0d | |||
| ea8245a91b | |||
| 759a932bb7 | |||
| fb6f05c267 | |||
| ff9b74efeb | |||
| d6e7543ba6 | |||
| e45d5700ec | |||
| 4e6682bd85 | |||
| 32c715c4d0 | |||
| c11cdf7468 | |||
| 6217c96576 | |||
| 977690590e | |||
| fd845c8b6c | |||
| d7d9abb007 | |||
| 9f22b2726b | |||
| f28b519556 | |||
| 762cf91133 | |||
| 9dd3dcff2b | |||
| 34fbcc9457 | |||
| 9cc8ac981b | |||
| 1153dcef69 | |||
| f811471b18 | |||
| 2382229c7d | |||
| f0e739be43 | |||
| 4dccdf9478 | |||
| 4c37d650d3 | |||
| 1b334e6966 | |||
| d463bd6323 | |||
| 8c298b33cd | |||
| dc1a380888 | |||
| 7e9be4d3d9 | |||
| 5579521ffc | |||
| ab1059134d | |||
| fe2ac66a52 | |||
| f87db2652b | |||
| 3f9f02b9e7 | |||
| 578247ffbc | |||
| 9a5f214623 | |||
| 141ca8904a | |||
| 4488c090b2 | |||
| 59c1fde351 | |||
| cf7ff76165 | |||
| ac79691d69 | |||
| 1a37989769 | |||
| 830f891a74 | |||
| 5937a66e22 | |||
| 894e38f713 | |||
| e4b5b0e5fd | |||
| 598dd1f816 | |||
| 35e24d4d14 | |||
| fea2ffb3ba | |||
| 64f55d55a1 | |||
| bfda4ce7e6 | |||
| 4f7cb7cd2a | |||
| 6517323add | |||
| 531a0b755a | |||
| 91bb8ae4d2 | |||
| 8cafc20098 | |||
| 9d5300440c | |||
| 58524d6d2b | |||
| 19cc6ea993 | |||
| d7f0a31e24 | |||
| 312974aa20 | |||
| d19c100166 | |||
| a8ad80c405 | |||
| 650e38e17f | |||
| 24612adf2c | |||
| 06649f6c21 | |||
| 8b61f5e9c4 | |||
| 6432898e7a | |||
| cced33d068 | |||
| bd01af6415 | |||
| 35011b810d | |||
| f295c7532c | |||
| 7065b67d07 | |||
| c0b50ef61d | |||
| 1d8cca4fa2 | |||
| 3474c179e6 | |||
| 433dad7e1a | |||
| be7ee380bc | |||
| cff5de626b | |||
| 4d8b8f9210 | |||
| a16ef7e73c | |||
| c39dae06d4 | |||
| f906e70f6b | |||
| 5139119307 | |||
| 1b537f904a | |||
| 556b631c54 | |||
| 49df9ceaf3 | |||
| 92ec1ac27a | |||
| e74097afdf | |||
| 8ddc4f2292 | |||
| 7b51320346 | |||
| 9e39be0770 | |||
| 3e5e87930c | |||
| 15a5ba67f1 | |||
| 9e3b4dc90d | |||
| 48c42a9fba | |||
| 0b35bc1ede | |||
| 8e01bb40fe | |||
| 9d21772820 | |||
| b745839bdb | |||
| 59ad6e02ce | |||
| a3b33cbe28 | |||
| 7b8540281a | |||
| 0a6b78f883 | |||
| 56ee8f7d64 | |||
| 3cfcd32876 | |||
| 06dcb55a9d | |||
| ec6cafd7aa | |||
| 6e9858960d | |||
| 150a8276b9 | |||
| c6a90d4bb3 | |||
| c71fd7113c | |||
| 5fc104a992 | |||
| d1de3cfb94 | |||
| 44d36f2460 | |||
| 9088f151d9 | |||
| c692962650 | |||
| f0a60a9000 | |||
| 2f50f3fd4b | |||
| 24cd7bbc62 | |||
| d299e75e1b | |||
| f86b6658c9 | |||
| 0a56d65581 | |||
| dfc03bac9f | |||
| 81e1376e08 | |||
| f50c85d536 | |||
| 5830c69694 | |||
| 0173496a77 | |||
| 30c5b47699 | |||
| e3191d4e91 | |||
| a9b3539b90 | |||
| 5217017e69 | |||
| bd5df5cf1c | |||
| 456dbfe7d7 | |||
| 586f210d6e | |||
| 275a0f9ddd | |||
| cbf2ba6cec | |||
| 1bd621f819 | |||
| bb6a331490 | |||
| 3922ad876f | |||
| fdb53fdeb1 | |||
| 3fb5a7bff1 | |||
| 6157c67cfe | |||
| fbc745764a | |||
| 78f09801b5 | |||
| d0dd81cf84 | |||
| 65b832c46c | |||
| a90b60c36f | |||
| 94a07706ec | |||
| ab2eacb6c1 | |||
| aead192743 | |||
| c1e8584b97 | |||
| 8a2b208299 | |||
| 2b6882bd97 | |||
| aa51662d98 | |||
| 3068526797 | |||
| 298d8c2d88 | |||
| 294e01a8c1 | |||
| 3a5aa4587c | |||
| cf1778e696 | |||
| 54db4c176a | |||
| 5d3e8a31d0 | |||
| 885dff82e3 | |||
| 3c4aa24198 | |||
| 33b0814323 | |||
| 45ae511036 | |||
| 0fa063c640 | |||
| 40d35304ea | |||
| 89821d66bb | |||
| 09d84e900c | |||
| a8746bff30 | |||
| c4d8bf0ce9 | |||
| 9cca605bac | |||
| dbd23f91e5 | |||
| 9387cc088c | |||
| 11f7a89e25 | |||
| 654d522b31 | |||
| 31e6ef77a6 | |||
| e56c847210 | |||
| e00172199a | |||
| 04f47836d8 | |||
| faaca822e4 | |||
| dc0f053925 | |||
| 517726da3a | |||
| 1d6c03eddf | |||
| fdfccd1205 | |||
| b30e7ced0a | |||
| 11770439be | |||
| d89c5f7146 | |||
| 4a475bf1cd | |||
| 10be9cfbbf | |||
| c20e0ad90d | |||
| 22f64d60bb | |||
| 7b7d332239 | |||
| b1d189324a | |||
| 00fb468f2e | |||
| bbbb6e04cb | |||
| f5161d9add | |||
| 787251f00e | |||
| cfe21f0826 | |||
| 196f691865 | |||
| 7a5bb1cfac | |||
| b80d55b764 | |||
| dd71625f52 | |||
| 19936d23d1 | |||
| decf0f3da0 | |||
| 7242a67f84 | |||
| c4884eb669 | |||
| d49f3327e4 | |||
| 633e68a2f7 | |||
| 809f48f733 | |||
| 578b1b45ea | |||
| 86c3c58e64 | |||
| 8d803a26eb | |||
| aa3129c2a9 | |||
| 97c924fe29 | |||
| 591c463e4b | |||
| e1691fddaa | |||
| b4d4351203 | |||
| f7b1348623 | |||
| 2619c7553a | |||
| f79d8baf63 | |||
| bbdcbac544 | |||
| d552680e72 | |||
| df43c6ab8a | |||
| cd47a47c3b | |||
| e5d4235f1b | |||
| f60aa36fa0 | |||
| b2bcb6d21a | |||
| b6cea71023 | |||
| 6462328620 | |||
| fd86cadf67 | |||
| c43c72c1a3 | |||
| d77c2e4d17 | |||
| 1a7898dff1 | |||
| af662b100b | |||
| 595df172a8 | |||
| 70bc5ca7f4 | |||
| 30617feff8 | |||
| 756864c85b | |||
| c8c94ef870 | |||
| 10d51ada59 | |||
| 00f3a53f1c | |||
| d2f0551170 | |||
| cba2b9b2ad | |||
| 029d5d36ac | |||
| 8d897153a5 | |||
| 2e914808ea | |||
| d00a72a435 | |||
| 36580221aa | |||
| e686cc9eab | |||
| 66196459d5 | |||
| a5387b304e | |||
| beb1448441 | |||
| 272102c06d | |||
| 36406cd62f | |||
| 87c41c88a3 | |||
| 095c56a646 | |||
| 244c132656 | |||
| 043ec46c33 | |||
| 0e4f19eee0 | |||
| ff34969f21 | |||
| 9a7245e1df | |||
| 4906eeac18 | |||
| 4da93ba579 | |||
| 319ecdd312 | |||
| 0c1ec35244 | |||
| 46375aacdb | |||
| e6d4331994 | |||
| 2a0abc51b1 | |||
| 3bb67885ef | |||
| e682749d03 | |||
| 9b83b0aadd | |||
| 0cac330bc2 | |||
| fb8114792a | |||
| eab6f65409 | |||
| 915023b809 | |||
| f104839672 | |||
| 6841a09667 | |||
| e937c8c72e | |||
| 960bb8a9b4 | |||
| 9b36059292 | |||
| a4acc64afd | |||
| 25c69ac540 | |||
| 96a0b9991e | |||
| 2913d17fe2 | |||
| d9e45a1abe | |||
| 24b4289d6c | |||
| fb6ccccc3d | |||
| 8b74ae683a | |||
| dd08957381 | |||
| 407323f817 | |||
| 2e2c87c5a1 | |||
| f4522fd695 | |||
| 760a2c656c | |||
| 8940decd1b | |||
| 0c4193bd91 | |||
| cd40cde790 | |||
| c60c754ac9 | |||
| ef80d3b707 | |||
| 24e8d21b3f | |||
| d823da18db | |||
| 1e3df09fc6 | |||
| 75a10c276c | |||
| 50050527eb | |||
| a39b185627 | |||
| 15270f09af | |||
| f6a5ac0698 | |||
| 2b79da722b | |||
| 71d69e43cd | |||
| 5bc6e8a433 | |||
| 68076f2e22 | |||
| 8c38363038 | |||
| 345ac8333c | |||
| 2375047ef0 | |||
| 857a48012e | |||
| 208fe3d7de | |||
| 92cddbcc02 | |||
| 599b53c9cb | |||
| 062b173c66 | |||
| db690013fd | |||
| e93bfe3d41 | |||
| ab910c736c | |||
| 4047a6bb12 | |||
| df2478dc26 | |||
| 4cc3f6045b | |||
| 1550316b8d | |||
| 87394d2512 | |||
| bad59c95bc | |||
| 9f138ef246 | |||
| 6453fc4973 | |||
| f62f926537 | |||
| b3dafd913b | |||
| b2d8a7eaf1 | |||
| 3e54414191 | |||
| a173546c8d | |||
| aa69d90489 | |||
| 4ba1292455 | |||
| bb01c31f30 | |||
| cd90b2ca9e | |||
| 9a65350cf7 | |||
| 680eb7a9f6 | |||
| 878420463c | |||
| 4692e20daf | |||
| 13fe2ca8fe | |||
| 1264e7d4f6 | |||
| 4f45978cd9 | |||
| 5a0bf8e028 | |||
| ffa163a8a8 | |||
| 8f86f5749d | |||
| 00d3bf15f3 | |||
| 7196c09e9d | |||
| fadd9e0bf4 | |||
| d8b4bbe067 | |||
| 24611e375a | |||
| ccec582cea | |||
| b2e4107c17 | |||
| 87aa070486 | |||
| 21230a8eb2 | |||
| 85cda47c70 | |||
| 7dadb33003 | |||
| b5a7e64e19 | |||
| b283b10d3e | |||
| ecb22226d6 | |||
| 8635aacb46 | |||
| bdd85b36a4 | |||
| a0c7713494 | |||
| abf4955c26 | |||
| 74340e3c04 | |||
| b98b389baf |
@ -1,4 +1,4 @@
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
||||
@ -11,7 +11,7 @@
|
||||
"nodeGypDependencies": true,
|
||||
"version": "lts"
|
||||
},
|
||||
"ghcr.io/devcontainers-contrib/features/npm-package:1": {
|
||||
"ghcr.io/devcontainers-extra/features/npm-package:1": {
|
||||
"package": "typescript",
|
||||
"version": "latest"
|
||||
},
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
#!/bin/bash
|
||||
WORKSPACE_ROOT=$(pwd)
|
||||
|
||||
corepack enable
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
|
||||
echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc
|
||||
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
|
||||
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
|
||||
|
||||
source /home/vscode/.bashrc
|
||||
|
||||
|
||||
3
.github/ISSUE_TEMPLATE/config.yml
vendored
3
.github/ISSUE_TEMPLATE/config.yml
vendored
@ -1,5 +1,8 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: "\U0001F510 Security Vulnerabilities"
|
||||
url: "https://github.com/langgenius/dify/security/advisories/new"
|
||||
about: Report security vulnerabilities through GitHub Security Advisories to ensure responsible disclosure. 💡 Please do not report security vulnerabilities in public issues.
|
||||
- name: "\U0001F4A1 Model Providers & Plugins"
|
||||
url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
|
||||
about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details.
|
||||
|
||||
30
.github/workflows/api-tests.yml
vendored
30
.github/workflows/api-tests.yml
vendored
@ -39,25 +39,11 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Run pyrefly check
|
||||
run: |
|
||||
cd api
|
||||
uv add --dev pyrefly
|
||||
uv run pyrefly check || true
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
@ -93,3 +79,19 @@ jobs:
|
||||
|
||||
- name: Run TestContainers
|
||||
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
||||
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
|
||||
15
.github/workflows/autofix.yml
vendored
15
.github/workflows/autofix.yml
vendored
@ -2,6 +2,8 @@ name: autofix.ci
|
||||
on:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: ["main"]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
@ -15,19 +17,28 @@ jobs:
|
||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
python-version: "3.11"
|
||||
- run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
# fmt first to avoid line too long
|
||||
uv run ruff format ..
|
||||
# Fix lint errors
|
||||
uv run ruff check --fix .
|
||||
# Format code
|
||||
uv run ruff format .
|
||||
uv run ruff format ..
|
||||
|
||||
- name: count migration progress
|
||||
run: |
|
||||
cd api
|
||||
./cnt_base.sh
|
||||
|
||||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
|
||||
# Convert Optional[T] to T | None (ignoring quoted types)
|
||||
cat > /tmp/optional-rule.yml << 'EOF'
|
||||
id: convert-optional-to-union
|
||||
|
||||
4
.github/workflows/build-push.yml
vendored
4
.github/workflows/build-push.yml
vendored
@ -4,10 +4,10 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
- "deploy/dev"
|
||||
- "deploy/enterprise"
|
||||
- "deploy/**"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "hotfix/**"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
3
.github/workflows/deploy-dev.yml
vendored
3
.github/workflows/deploy-dev.yml
vendored
@ -12,7 +12,8 @@ jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success'
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
name: Deploy RAG Dev
|
||||
name: Deploy Trigger Dev
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@ -7,7 +7,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/rag-dev"
|
||||
- "deploy/trigger-dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@ -16,12 +16,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
host: ${{ secrets.TRIGGER_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
3
.github/workflows/expose_service_ports.sh
vendored
3
.github/workflows/expose_service_ports.sh
vendored
@ -1,6 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
||||
@ -13,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
|
||||
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
|
||||
|
||||
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||
|
||||
11
.github/workflows/style.yml
vendored
11
.github/workflows/style.yml
vendored
@ -12,7 +12,6 @@ permissions:
|
||||
statuses: write
|
||||
contents: read
|
||||
|
||||
|
||||
jobs:
|
||||
python-style:
|
||||
name: Python Style
|
||||
@ -44,6 +43,10 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run Import Linter
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --directory api --dev lint-imports
|
||||
|
||||
- name: Run Basedpyright Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: dev/basedpyright-check
|
||||
@ -99,7 +102,11 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: |
|
||||
pnpm run lint
|
||||
pnpm run eslint
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run type-check
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
|
||||
5
.github/workflows/vdb-tests.yml
vendored
5
.github/workflows/vdb-tests.yml
vendored
@ -1,7 +1,10 @@
|
||||
name: Run VDB Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'api/core/rag/*.py'
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
13
.gitignore
vendored
13
.gitignore
vendored
@ -6,6 +6,9 @@ __pycache__/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# *db files
|
||||
*.db
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
@ -97,6 +100,7 @@ __pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat-schedule.db
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
@ -230,4 +234,11 @@ api/.env.backup
|
||||
|
||||
# Benchmark
|
||||
scripts/stress-test/setup/config/
|
||||
scripts/stress-test/reports/
|
||||
scripts/stress-test/reports/
|
||||
|
||||
# mcp
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
|
||||
# settings
|
||||
*.local.json
|
||||
|
||||
9
.vscode/launch.json.template
vendored
9
.vscode/launch.json.template
vendored
@ -8,8 +8,7 @@
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_ENV": "development",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
"FLASK_ENV": "development"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
@ -28,9 +27,7 @@
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"env": {
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"app.celery",
|
||||
@ -40,7 +37,7 @@
|
||||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace",
|
||||
"dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
|
||||
54
AGENTS.md
Normal file
54
AGENTS.md
Normal file
@ -0,0 +1,54 @@
|
||||
# AGENTS.md
|
||||
|
||||
## Project Overview
|
||||
|
||||
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
|
||||
|
||||
The codebase is split into:
|
||||
|
||||
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
|
||||
- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
|
||||
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
||||
|
||||
## Backend Workflow
|
||||
|
||||
- Run backend CLI commands through `uv run --project api <command>`.
|
||||
|
||||
- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
|
||||
|
||||
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
|
||||
|
||||
- Integration tests are CI-only and are not expected to run in the local environment.
|
||||
|
||||
## Frontend Workflow
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm lint
|
||||
pnpm lint:fix
|
||||
pnpm test
|
||||
```
|
||||
|
||||
## Testing & Quality Practices
|
||||
|
||||
- Follow TDD: red → green → refactor.
|
||||
- Use `pytest` for backend tests with Arrange-Act-Assert structure.
|
||||
- Enforce strong typing; avoid `Any` and prefer explicit type annotations.
|
||||
- Write self-documenting code; only add comments that explain intent.
|
||||
|
||||
## Language Style
|
||||
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`).
|
||||
- **TypeScript**: Use the strict config, lean on ESLint + Prettier workflows, and avoid `any` types.
|
||||
|
||||
## General Practices
|
||||
|
||||
- Prefer editing existing files; add new documentation only when requested.
|
||||
- Inject dependencies through constructors and preserve clean architecture boundaries.
|
||||
- Handle errors with domain-specific exceptions at the correct layer.
|
||||
|
||||
## Project Conventions
|
||||
|
||||
- Backend architecture adheres to DDD and Clean Architecture principles.
|
||||
- Async work runs through Celery with Redis as the broker.
|
||||
- Frontend user-facing strings must use `web/i18n/en-US/`; avoid hardcoded text.
|
||||
89
CLAUDE.md
89
CLAUDE.md
@ -1,89 +0,0 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
Dify is an open-source platform for developing LLM applications with an intuitive interface combining agentic AI workflows, RAG pipelines, agent capabilities, and model management.
|
||||
|
||||
The codebase consists of:
|
||||
|
||||
- **Backend API** (`/api`): Python Flask application with Domain-Driven Design architecture
|
||||
- **Frontend Web** (`/web`): Next.js 15 application with TypeScript and React 19
|
||||
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Backend (API)
|
||||
|
||||
All Python commands must be prefixed with `uv run --project api`:
|
||||
|
||||
```bash
|
||||
# Start development servers
|
||||
./dev/start-api # Start API server
|
||||
./dev/start-worker # Start Celery worker
|
||||
|
||||
# Run tests
|
||||
uv run --project api pytest # Run all tests
|
||||
uv run --project api pytest tests/unit_tests/ # Unit tests only
|
||||
uv run --project api pytest tests/integration_tests/ # Integration tests
|
||||
|
||||
# Code quality
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
||||
uv run --project api ruff format ./ # Format code
|
||||
uv run --directory api basedpyright # Type checking
|
||||
```
|
||||
|
||||
### Frontend (Web)
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm lint # Run ESLint
|
||||
pnpm eslint-fix # Fix ESLint issues
|
||||
pnpm test # Run Jest tests
|
||||
```
|
||||
|
||||
## Testing Guidelines
|
||||
|
||||
### Backend Testing
|
||||
|
||||
- Use `pytest` for all backend tests
|
||||
- Write tests first (TDD approach)
|
||||
- Test structure: Arrange-Act-Assert
|
||||
|
||||
## Code Style Requirements
|
||||
|
||||
### Python
|
||||
|
||||
- Use type hints for all functions and class attributes
|
||||
- No `Any` types unless absolutely necessary
|
||||
- Implement special methods (`__repr__`, `__str__`) appropriately
|
||||
|
||||
### TypeScript/JavaScript
|
||||
|
||||
- Strict TypeScript configuration
|
||||
- ESLint with Prettier integration
|
||||
- Avoid `any` type
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Environment Variables**: Always use UV for Python commands: `uv run --project api <command>`
|
||||
- **Comments**: Only write meaningful comments that explain "why", not "what"
|
||||
- **File Creation**: Always prefer editing existing files over creating new ones
|
||||
- **Documentation**: Don't create documentation files unless explicitly requested
|
||||
- **Code Quality**: Always run `./dev/reformat` before committing backend changes
|
||||
|
||||
## Common Development Tasks
|
||||
|
||||
### Adding a New API Endpoint
|
||||
|
||||
1. Create controller in `/api/controllers/`
|
||||
1. Add service logic in `/api/services/`
|
||||
1. Update routes in controller's `__init__.py`
|
||||
1. Write tests in `/api/tests/`
|
||||
|
||||
## Project-Specific Conventions
|
||||
|
||||
- All async tasks use Celery with Redis as broker
|
||||
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
|
||||
6
Makefile
6
Makefile
@ -26,7 +26,6 @@ prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cd web && pnpm build
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
@ -61,8 +60,9 @@ check:
|
||||
@echo "✅ Code check complete"
|
||||
|
||||
lint:
|
||||
@echo "🔧 Running ruff format and check with fixes..."
|
||||
@uv run --directory api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
|
||||
@echo "🔧 Running ruff format, check with fixes, and import linter..."
|
||||
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
|
||||
@uv run --directory api --dev lint-imports
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
|
||||
44
README.md
44
README.md
@ -40,18 +40,18 @@
|
||||
|
||||
<p align="center">
|
||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
<a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
|
||||
@ -63,7 +63,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
|
||||
> - CPU >= 2 Core
|
||||
> - RAM >= 4 GiB
|
||||
|
||||
</br>
|
||||
<br/>
|
||||
|
||||
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
||||
|
||||
@ -109,15 +109,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
|
||||
|
||||
## Using Dify
|
||||
|
||||
- **Cloud </br>**
|
||||
- **Cloud <br/>**
|
||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||
|
||||
- **Self-hosting Dify Community Edition</br>**
|
||||
- **Self-hosting Dify Community Edition<br/>**
|
||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||
|
||||
- **Dify for enterprise / organizations</br>**
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br>
|
||||
- **Dify for enterprise / organizations<br/>**
|
||||
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
|
||||
|
||||
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
||||
|
||||
@ -129,8 +129,18 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
## Advanced Setup
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
|
||||
|
||||
- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
|
||||
|
||||
### Deployment with Kubernetes
|
||||
|
||||
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||
|
||||
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
||||
|
||||
@ -27,6 +27,9 @@ FILES_URL=http://localhost:5001
|
||||
# Example: INTERNAL_FILES_URL=http://api:5001
|
||||
INTERNAL_FILES_URL=http://127.0.0.1:5001
|
||||
|
||||
# TRIGGER URL
|
||||
TRIGGER_URL=http://localhost:5001
|
||||
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
@ -76,6 +79,7 @@ DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_DATABASE=dify
|
||||
SQLALCHEMY_POOL_PRE_PING=true
|
||||
SQLALCHEMY_POOL_TIMEOUT=30
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
@ -155,6 +159,8 @@ SUPABASE_URL=your-server-url
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). Leading dots are optional.
|
||||
COOKIE_DOMAIN=
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
@ -303,6 +309,8 @@ BAIDU_VECTOR_DB_API_KEY=dify
|
||||
BAIDU_VECTOR_DB_DATABASE=dify
|
||||
BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER=DEFAULT_ANALYZER
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE=COARSE_MODE
|
||||
|
||||
# Upstash configuration
|
||||
UPSTASH_VECTOR_URL=your-server-url
|
||||
@ -328,7 +336,7 @@ MATRIXONE_DATABASE=dify
|
||||
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
|
||||
LINDORM_USERNAME=admin
|
||||
LINDORM_PASSWORD=admin
|
||||
USING_UGC_INDEX=False
|
||||
LINDORM_USING_UGC=True
|
||||
LINDORM_QUERY_TIMEOUT=1
|
||||
|
||||
# OceanBase Vector configuration
|
||||
@ -340,6 +348,15 @@ OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||
|
||||
# AlibabaCloud MySQL Vector configuration
|
||||
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
|
||||
ALIBABACLOUD_MYSQL_PORT=3306
|
||||
ALIBABACLOUD_MYSQL_USER=root
|
||||
ALIBABACLOUD_MYSQL_PASSWORD=root
|
||||
ALIBABACLOUD_MYSQL_DATABASE=dify
|
||||
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
|
||||
ALIBABACLOUD_MYSQL_HNSW_M=6
|
||||
|
||||
# openGauss configuration
|
||||
OPENGAUSS_HOST=127.0.0.1
|
||||
OPENGAUSS_PORT=6600
|
||||
@ -356,6 +373,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||
|
||||
# Comma-separated list of file extensions blocked from upload for security reasons.
|
||||
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
|
||||
# Empty by default to allow all file types.
|
||||
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
|
||||
UPLOAD_FILE_EXTENSION_BLACKLIST=
|
||||
|
||||
# Model configuration
|
||||
MULTIMODAL_SEND_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=512
|
||||
@ -405,6 +428,9 @@ SSRF_DEFAULT_TIME_OUT=5
|
||||
SSRF_DEFAULT_CONNECT_TIME_OUT=5
|
||||
SSRF_DEFAULT_READ_TIME_OUT=5
|
||||
SSRF_DEFAULT_WRITE_TIME_OUT=5
|
||||
SSRF_POOL_MAX_CONNECTIONS=100
|
||||
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
||||
SSRF_POOL_KEEPALIVE_EXPIRY=5.0
|
||||
|
||||
BATCH_UPLOAD_LIMIT=10
|
||||
KEYWORD_DATA_SOURCE_TYPE=database
|
||||
@ -415,10 +441,17 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||
# CODE EXECUTION CONFIGURATION
|
||||
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
|
||||
CODE_EXECUTION_API_KEY=dify-sandbox
|
||||
CODE_EXECUTION_SSL_VERIFY=True
|
||||
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
|
||||
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
||||
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
|
||||
CODE_EXECUTION_CONNECT_TIMEOUT=10
|
||||
CODE_EXECUTION_READ_TIMEOUT=60
|
||||
CODE_EXECUTION_WRITE_TIMEOUT=10
|
||||
CODE_MAX_NUMBER=9223372036854775807
|
||||
CODE_MIN_NUMBER=-9223372036854775808
|
||||
CODE_MAX_STRING_LENGTH=80000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
|
||||
CODE_MAX_STRING_LENGTH=400000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=400000
|
||||
CODE_MAX_STRING_ARRAY_LENGTH=30
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH=30
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
||||
@ -435,6 +468,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY=True
|
||||
|
||||
# Webhook request configuration
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
|
||||
|
||||
# Respect X-* headers to redirect clients
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
@ -458,9 +494,18 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
|
||||
WORKFLOW_MAX_EXECUTION_STEPS=500
|
||||
WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||
WORKFLOW_CALL_MAX_DEPTH=5
|
||||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
|
||||
# Seconds of idle time before scaling down workers (default: 5.0)
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
|
||||
|
||||
# Workflow storage configuration
|
||||
# Options: rdbms, hybrid
|
||||
# rdbms: Use only the relational database (default)
|
||||
@ -481,7 +526,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
|
||||
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
|
||||
# Workflow log cleanup configuration
|
||||
# Enable automatic cleanup of workflow run logs to manage database size
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=true
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=false
|
||||
# Number of days to retain workflow run logs (default: 30 days)
|
||||
WORKFLOW_LOG_RETENTION_DAYS=30
|
||||
# Batch size for workflow log cleanup operations (default: 100)
|
||||
@ -503,6 +548,12 @@ ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
|
||||
# Interval time in minutes for polling scheduled workflows(default: 1 min)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
|
||||
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
|
||||
|
||||
# Position configuration
|
||||
POSITION_TOOL_PINS=
|
||||
@ -574,3 +625,9 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||
|
||||
105
api/.importlinter
Normal file
105
api/.importlinter
Normal file
@ -0,0 +1,105 @@
|
||||
[importlinter]
|
||||
root_packages =
|
||||
core
|
||||
configs
|
||||
controllers
|
||||
models
|
||||
tasks
|
||||
services
|
||||
|
||||
[importlinter:contract:workflow]
|
||||
name = Workflow
|
||||
type=layers
|
||||
layers =
|
||||
graph_engine
|
||||
graph_events
|
||||
graph
|
||||
nodes
|
||||
node_events
|
||||
entities
|
||||
containers =
|
||||
core.workflow
|
||||
ignore_imports =
|
||||
core.workflow.nodes.base.node -> core.workflow.graph_events
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
|
||||
|
||||
core.workflow.nodes.node_factory -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
response_coordinator
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:worker]
|
||||
name = Worker
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
worker
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:graph-engine-architecture]
|
||||
name = Graph Engine Architecture
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
orchestration
|
||||
command_processing
|
||||
event_management
|
||||
error_handler
|
||||
graph_traversal
|
||||
graph_state_manager
|
||||
worker_management
|
||||
domain
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:domain-isolation]
|
||||
name = Domain Model Isolation
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.domain
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
core.workflow.graph_engine.command_channels
|
||||
core.workflow.graph_engine.layers
|
||||
core.workflow.graph_engine.protocols
|
||||
|
||||
[importlinter:contract:worker-management]
|
||||
name = Worker Management
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.orchestration
|
||||
core.workflow.graph_engine.command_processing
|
||||
core.workflow.graph_engine.event_management
|
||||
|
||||
|
||||
[importlinter:contract:graph-traversal-components]
|
||||
name = Graph Traversal Components
|
||||
type = layers
|
||||
layers =
|
||||
edge_processor
|
||||
skip_propagator
|
||||
containers =
|
||||
core.workflow.graph_engine.graph_traversal
|
||||
|
||||
[importlinter:contract:command-channels]
|
||||
name = Command Channels Independence
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.command_channels.in_memory_channel
|
||||
core.workflow.graph_engine.command_channels.redis_channel
|
||||
@ -5,7 +5,7 @@ line-length = 120
|
||||
quote-style = "double"
|
||||
|
||||
[lint]
|
||||
preview = false
|
||||
preview = true
|
||||
select = [
|
||||
"B", # flake8-bugbear rules
|
||||
"C4", # flake8-comprehensions
|
||||
@ -30,6 +30,7 @@ select = [
|
||||
"RUF022", # unsorted-dunder-all
|
||||
"S506", # unsafe-yaml-load
|
||||
"SIM", # flake8-simplify rules
|
||||
"T201", # print-found
|
||||
"TRY400", # error-instead-of-exception
|
||||
"TRY401", # verbose-log-message
|
||||
"UP", # pyupgrade rules
|
||||
@ -45,6 +46,7 @@ select = [
|
||||
"G001", # don't use str format to logging messages
|
||||
"G003", # don't use + in logging messages
|
||||
"G004", # don't use f-strings to format logging messages
|
||||
"UP042", # use StrEnum
|
||||
]
|
||||
|
||||
ignore = [
|
||||
@ -64,6 +66,7 @@ ignore = [
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
"B901", # allow return in yield
|
||||
"B903", # class-as-data-structure
|
||||
"B904", # raise-without-from-inside-except
|
||||
"B905", # zip-without-explicit-strict
|
||||
@ -78,7 +81,6 @@ ignore = [
|
||||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
@ -89,11 +91,18 @@ ignore = [
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"core/model_runtime/callbacks/base_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = [
|
||||
"T201",
|
||||
]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests
|
||||
]
|
||||
|
||||
[lint.pyflakes]
|
||||
|
||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@ -54,7 +54,7 @@
|
||||
"--loglevel",
|
||||
"DEBUG",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace,app_deletion"
|
||||
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
62
api/AGENTS.md
Normal file
62
api/AGENTS.md
Normal file
@ -0,0 +1,62 @@
|
||||
# Agent Skill Index
|
||||
|
||||
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Platform Foundations
|
||||
|
||||
- **[Infrastructure Overview](agent_skills/infra.md)**\
|
||||
When to read this:
|
||||
|
||||
- You need to understand where a feature belongs in the architecture.
|
||||
- You’re wiring storage, Redis, vector stores, or OTEL.
|
||||
- You’re about to add CLI commands or async jobs.\
|
||||
What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
|
||||
|
||||
- **[Coding Style](agent_skills/coding_style.md)**\
|
||||
When to read this:
|
||||
|
||||
- You’re writing or reviewing backend code and need the authoritative checklist.
|
||||
- You’re unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
|
||||
- You want the exact lint/type/test commands used in PRs.\
|
||||
Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Plugin & Extension Development
|
||||
|
||||
- **[Plugin Systems](agent_skills/plugin.md)**\
|
||||
When to read this:
|
||||
|
||||
- You’re building or debugging a marketplace plugin.
|
||||
- You need to know how manifests, providers, daemons, and migrations fit together.\
|
||||
What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
|
||||
|
||||
- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
|
||||
When to read this:
|
||||
|
||||
- You must integrate OAuth for a plugin or datasource.
|
||||
- You’re handling credential encryption or refresh flows.\
|
||||
Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Workflow Entry & Execution
|
||||
|
||||
- **[Trigger Concepts](agent_skills/trigger.md)**\
|
||||
When to read this:
|
||||
- You’re debugging why a workflow didn’t start.
|
||||
- You’re adding a new trigger type or hook.
|
||||
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
|
||||
Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Additional Notes for Agents
|
||||
|
||||
- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
|
||||
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
|
||||
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
|
||||
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
|
||||
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
|
||||
@ -15,7 +15,11 @@ FROM base AS packages
|
||||
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies
|
||||
COPY pyproject.toml uv.lock ./
|
||||
@ -49,7 +53,9 @@ RUN \
|
||||
# Install dependencies
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
|
||||
curl nodejs \
|
||||
# for gmpy2 \
|
||||
libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# For Security
|
||||
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
# install fonts to support the use of tools like pypdfium2
|
||||
|
||||
@ -26,6 +26,10 @@
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
|
||||
|
||||
1. Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
bash for Linux
|
||||
@ -80,10 +84,10 @@
|
||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
|
||||
```
|
||||
|
||||
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery beat
|
||||
|
||||
115
api/agent_skills/coding_style.md
Normal file
115
api/agent_skills/coding_style.md
Normal file
@ -0,0 +1,115 @@
|
||||
## Linter
|
||||
|
||||
- Always follow `.ruff.toml`.
|
||||
- Run `uv run ruff check --fix --unsafe-fixes`.
|
||||
- Keep each line under 100 characters (including spaces).
|
||||
|
||||
## Code Style
|
||||
|
||||
- `snake_case` for variables and functions.
|
||||
- `PascalCase` for classes.
|
||||
- `UPPER_CASE` for constants.
|
||||
|
||||
## Rules
|
||||
|
||||
- Use Pydantic v2 standard.
|
||||
- Use `uv` for package management.
|
||||
- Do not override dunder methods like `__init__`, `__iadd__`, etc.
|
||||
- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
|
||||
- Prefer simple functions over classes for lightweight helpers.
|
||||
- Keep files below 800 lines; split when necessary.
|
||||
- Keep code readable—no clever hacks.
|
||||
- Never use `print`; log with `logger = logging.getLogger(__name__)`.
|
||||
|
||||
## Guiding Principles
|
||||
|
||||
- Mirror the project’s layered architecture: controller → service → core/domain.
|
||||
- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
|
||||
- Optimise for observability: deterministic control flow, clear logging, actionable errors.
|
||||
|
||||
## SQLAlchemy Patterns
|
||||
|
||||
- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
|
||||
|
||||
- Open sessions with context managers:
|
||||
|
||||
```python
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.tenant_id == tenant_id,
|
||||
)
|
||||
workflow = session.execute(stmt).scalar_one_or_none()
|
||||
```
|
||||
|
||||
- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
|
||||
|
||||
- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
|
||||
|
||||
- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
|
||||
|
||||
## Storage & External IO
|
||||
|
||||
- Access storage via `extensions.ext_storage.storage`.
|
||||
- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
|
||||
- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
|
||||
|
||||
## Pydantic Usage
|
||||
|
||||
- Define DTOs with Pydantic v2 models and forbid extras by default.
|
||||
|
||||
- Use `@field_validator` / `@model_validator` for domain rules.
|
||||
|
||||
- Example:
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
|
||||
|
||||
class TriggerConfig(BaseModel):
|
||||
endpoint: HttpUrl
|
||||
secret: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("secret")
|
||||
def ensure_secret_prefix(cls, value: str) -> str:
|
||||
if not value.startswith("dify_"):
|
||||
raise ValueError("secret must start with dify_")
|
||||
return value
|
||||
```
|
||||
|
||||
## Generics & Protocols
|
||||
|
||||
- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
|
||||
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
|
||||
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
|
||||
|
||||
## Error Handling & Logging
|
||||
|
||||
- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
|
||||
- Declare `logger = logging.getLogger(__name__)` at module top.
|
||||
- Include tenant/app/workflow identifiers in log context.
|
||||
- Log retryable events at `warning`, terminal failures at `error`.
|
||||
|
||||
## Tooling & Checks
|
||||
|
||||
- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
|
||||
- Type checks: `uv run --directory api --dev basedpyright`.
|
||||
- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
|
||||
- Run all of the above before submitting your work.
|
||||
|
||||
## Controllers & Services
|
||||
|
||||
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
|
||||
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
|
||||
- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
|
||||
- Document non-obvious behaviour with concise comments.
|
||||
|
||||
## Miscellaneous
|
||||
|
||||
- Use `configs.dify_config` for configuration—never read environment variables directly.
|
||||
- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
|
||||
- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
|
||||
- Keep experimental scripts under `dev/`; do not ship them in production builds.
|
||||
96
api/agent_skills/infra.md
Normal file
96
api/agent_skills/infra.md
Normal file
@ -0,0 +1,96 @@
|
||||
## Configuration
|
||||
|
||||
- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
|
||||
- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
|
||||
- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
|
||||
- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
|
||||
|
||||
## Dependencies
|
||||
|
||||
- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
|
||||
- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
|
||||
- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
|
||||
|
||||
## Storage & Files
|
||||
|
||||
- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
|
||||
- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
|
||||
- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
|
||||
- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
|
||||
|
||||
## Redis & Shared State
|
||||
|
||||
- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
|
||||
- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
|
||||
|
||||
## Models
|
||||
|
||||
- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
|
||||
- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
|
||||
- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
|
||||
- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
|
||||
|
||||
## Vector Stores
|
||||
|
||||
- Vector client implementations live in `core/rag/datasource/vdb/<provider>`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
|
||||
- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
|
||||
- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
|
||||
- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
|
||||
|
||||
## Observability & OTEL
|
||||
|
||||
- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
|
||||
- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
|
||||
- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
|
||||
- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
|
||||
|
||||
## Ops Integrations
|
||||
|
||||
- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
|
||||
- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
|
||||
- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
|
||||
|
||||
## Controllers, Services, Core
|
||||
|
||||
- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
|
||||
- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
|
||||
- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
|
||||
|
||||
## Plugins, Tools, Providers
|
||||
|
||||
- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
|
||||
- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
|
||||
- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
|
||||
- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
|
||||
- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
|
||||
- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
|
||||
|
||||
## Async Workloads
|
||||
|
||||
see `agent_skills/trigger.md` for more detailed documentation.
|
||||
|
||||
- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
|
||||
- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
|
||||
- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
|
||||
- Generate migrations with `uv run --project api flask db revision --autogenerate -m "<summary>"`, then review the diff; never hand-edit the database outside Alembic.
|
||||
- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
|
||||
- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
|
||||
|
||||
## CLI Commands
|
||||
|
||||
- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask <command>`.
|
||||
- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
|
||||
- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
|
||||
- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
|
||||
- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
|
||||
|
||||
## When You Add Features
|
||||
|
||||
- Check for an existing helper or service before writing a new util.
|
||||
- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
|
||||
- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
|
||||
- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.
|
||||
1
api/agent_skills/plugin.md
Normal file
1
api/agent_skills/plugin.md
Normal file
@ -0,0 +1 @@
|
||||
// TBD
|
||||
1
api/agent_skills/plugin_oauth.md
Normal file
1
api/agent_skills/plugin_oauth.md
Normal file
@ -0,0 +1 @@
|
||||
// TBD
|
||||
53
api/agent_skills/trigger.md
Normal file
53
api/agent_skills/trigger.md
Normal file
@ -0,0 +1,53 @@
|
||||
## Overview
|
||||
|
||||
Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node.
|
||||
|
||||
## Trigger nodes
|
||||
|
||||
- `UserInput`
|
||||
- `Trigger Webhook`
|
||||
- `Trigger Schedule`
|
||||
- `Trigger Plugin`
|
||||
|
||||
### UserInput
|
||||
|
||||
Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app`
|
||||
|
||||
1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool.
|
||||
1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node.
|
||||
1. For its detailed implementation, please refer to `core/workflow/nodes/start`
|
||||
|
||||
### Trigger Webhook
|
||||
|
||||
Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`.
|
||||
|
||||
Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution.
|
||||
|
||||
### Trigger Schedule
|
||||
|
||||
`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help.
|
||||
|
||||
To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published.
|
||||
|
||||
### Trigger Plugin
|
||||
|
||||
`Trigger Plugin` node allows user define there own distributed trigger plugin, whenever a request was received, Dify forwards it to the plugin and wait for parsed variables from it.
|
||||
|
||||
1. Requests were saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`
|
||||
1. Plugins accept those requests and parse variables from it, see `core/plugin/impl/trigger.py` for details.
|
||||
|
||||
A `subscription` concept was out here by Dify, it means an endpoint address from Dify was bound to thirdparty webhook service like `Github` `Slack` `Linear` `GoogleDrive` `Gmail` etc. Once a subscription was created, Dify continually receives requests from the platforms and handle them one by one.
|
||||
|
||||
## Worker Pool / Async Task
|
||||
|
||||
All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
|
||||
|
||||
The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`.
|
||||
|
||||
## Debug Strategy
|
||||
|
||||
Dify divided users into 2 groups: builders / end users.
|
||||
|
||||
Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
|
||||
|
||||
A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type.
|
||||
26
api/app.py
26
api/app.py
@ -1,8 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def is_db_command():
|
||||
def is_db_command() -> bool:
|
||||
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
|
||||
return True
|
||||
return False
|
||||
@ -14,23 +13,12 @@ if is_db_command():
|
||||
|
||||
app = create_migrations_app()
|
||||
else:
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
from gevent import monkey
|
||||
|
||||
# gevent
|
||||
monkey.patch_all()
|
||||
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
|
||||
import psycogreen.gevent # type: ignore
|
||||
|
||||
psycogreen.gevent.patch_psycopg()
|
||||
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||
#
|
||||
# See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
|
||||
#
|
||||
# For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
|
||||
13
api/celery_entrypoint.py
Normal file
13
api/celery_entrypoint.py
Normal file
@ -0,0 +1,13 @@
|
||||
import psycogreen.gevent as pscycogreen_gevent # type: ignore
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
print("gRPC patched with gevent.", flush=True) # noqa: T201
|
||||
pscycogreen_gevent.patch_psycopg()
|
||||
print("psycopg2 patched with gevent.", flush=True) # noqa: T201
|
||||
|
||||
|
||||
from app import app, celery
|
||||
|
||||
__all__ = ["app", "celery"]
|
||||
7
api/cnt_base.sh
Executable file
7
api/cnt_base.sh
Executable file
@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
set -euxo pipefail
|
||||
|
||||
for pattern in "Base" "TypeBase"; do
|
||||
printf "%s " "$pattern"
|
||||
grep "($pattern):" -r --include='*.py' --exclude-dir=".venv" --exclude-dir="tests" . | wc -l
|
||||
done
|
||||
731
api/commands.py
731
api/commands.py
@ -10,10 +10,13 @@ from flask import current_app
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
@ -23,19 +26,25 @@ from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
from models import Tenant
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -53,31 +62,30 @@ def reset_password(email, new_password, password_confirm):
|
||||
if str(new_password).strip() != str(password_confirm).strip():
|
||||
click.echo(click.style("Passwords do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
valid_password(new_password)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
|
||||
return
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(new_password, salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
AccountService.reset_login_error_rate_limit(email)
|
||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command("reset-email", help="Reset the account email.")
|
||||
@ -92,22 +100,21 @@ def reset_email(email, new_email, email_confirm):
|
||||
if str(new_email).strip() != str(email_confirm).strip():
|
||||
click.echo(click.style("New emails do not match.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = session.query(Account).where(Account.email == email).one_or_none()
|
||||
|
||||
account = db.session.query(Account).where(Account.email == email).one_or_none()
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
|
||||
if not account:
|
||||
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
|
||||
return
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
email_validate(new_email)
|
||||
except:
|
||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||
return
|
||||
|
||||
account.email = new_email
|
||||
db.session.commit()
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
account.email = new_email
|
||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||
|
||||
|
||||
@click.command(
|
||||
@ -131,25 +138,24 @@ def reset_encrypt_key_pair():
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
|
||||
return
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
tenants = session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
|
||||
tenants = db.session.query(Tenant).all()
|
||||
for tenant in tenants:
|
||||
if not tenant:
|
||||
click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
|
||||
return
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
|
||||
tenant.encrypt_public_key = generate_key_pair(tenant.id)
|
||||
session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
|
||||
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
|
||||
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
|
||||
db.session.commit()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command("vdb-migrate", help="Migrate vector db.")
|
||||
@ -174,14 +180,15 @@ def migrate_annotation_vector_database():
|
||||
try:
|
||||
# get apps info
|
||||
per_page = 50
|
||||
apps = (
|
||||
db.session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
apps = (
|
||||
session.query(App)
|
||||
.where(App.status == "normal")
|
||||
.order_by(App.created_at.desc())
|
||||
.limit(per_page)
|
||||
.offset((page - 1) * per_page)
|
||||
.all()
|
||||
)
|
||||
if not apps:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
@ -195,26 +202,27 @@ def migrate_annotation_vector_database():
|
||||
)
|
||||
try:
|
||||
click.echo(f"Creating app annotation index: {app.id}")
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
app_annotation_setting = (
|
||||
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
|
||||
)
|
||||
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = db.session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
if not app_annotation_setting:
|
||||
skipped_count = skipped_count + 1
|
||||
click.echo(f"App annotation setting disabled: {app.id}")
|
||||
continue
|
||||
# get dataset_collection_binding info
|
||||
dataset_collection_binding = (
|
||||
session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset_collection_binding:
|
||||
click.echo(f"App annotation collection binding not found: {app.id}")
|
||||
continue
|
||||
annotations = session.scalars(
|
||||
select(MessageAnnotation).where(MessageAnnotation.app_id == app.id)
|
||||
).all()
|
||||
dataset = Dataset(
|
||||
id=app.id,
|
||||
tenant_id=app.tenant_id,
|
||||
@ -313,6 +321,8 @@ def migrate_knowledge_vector_database():
|
||||
)
|
||||
|
||||
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
||||
if not datasets.items:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
|
||||
@ -731,18 +741,18 @@ where sites.id is null limit 1000"""
|
||||
try:
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
if not app:
|
||||
print(f"App {app_id} not found")
|
||||
logger.info("App %s not found", app_id)
|
||||
continue
|
||||
|
||||
tenant = app.tenant
|
||||
if tenant:
|
||||
accounts = tenant.get_accounts()
|
||||
if not accounts:
|
||||
print(f"Fix failed for app {app.id}")
|
||||
logger.info("Fix failed for app %s", app.id)
|
||||
continue
|
||||
|
||||
account = accounts[0]
|
||||
print(f"Fixing missing site for app {app.id}")
|
||||
logger.info("Fixing missing site for app %s", app.id)
|
||||
app_was_created.send(app, account=account)
|
||||
except Exception:
|
||||
failed_app_ids.append(app_id)
|
||||
@ -953,7 +963,7 @@ def clear_orphaned_file_records(force: bool):
|
||||
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
|
||||
query = "DELETE FROM message_files WHERE id IN :ids"
|
||||
with db.engine.begin() as conn:
|
||||
conn.execute(sa.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
|
||||
conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)})
|
||||
click.echo(
|
||||
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
|
||||
)
|
||||
@ -1219,6 +1229,55 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_system_trigger_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup system trigger oauth client
|
||||
"""
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerOAuthSystemClient
|
||||
|
||||
provider_id = TriggerProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
oauth_client = TriggerOAuthSystemClient(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
encrypted_oauth_params=oauth_client_params,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
"""
|
||||
Find draft variables that reference non-existent apps.
|
||||
@ -1245,15 +1304,17 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
|
||||
def _count_orphaned_draft_variables() -> dict[str, Any]:
|
||||
"""
|
||||
Count orphaned draft variables by app.
|
||||
Count orphaned draft variables by app, including associated file counts.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics about orphaned variables
|
||||
Dictionary with statistics about orphaned variables and files
|
||||
"""
|
||||
query = """
|
||||
# Count orphaned variables by app
|
||||
variables_query = """
|
||||
SELECT
|
||||
wdv.app_id,
|
||||
COUNT(*) as variable_count
|
||||
COUNT(*) as variable_count,
|
||||
COUNT(wdv.file_id) as file_count
|
||||
FROM workflow_draft_variables AS wdv
|
||||
WHERE NOT EXISTS(
|
||||
SELECT 1 FROM apps WHERE apps.id = wdv.app_id
|
||||
@ -1263,14 +1324,21 @@ def _count_orphaned_draft_variables() -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
with db.engine.connect() as conn:
|
||||
result = conn.execute(sa.text(query))
|
||||
orphaned_by_app = {row[0]: row[1] for row in result}
|
||||
result = conn.execute(sa.text(variables_query))
|
||||
orphaned_by_app = {}
|
||||
total_files = 0
|
||||
|
||||
total_orphaned = sum(orphaned_by_app.values())
|
||||
for row in result:
|
||||
app_id, variable_count, file_count = row
|
||||
orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count}
|
||||
total_files += file_count
|
||||
|
||||
total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values())
|
||||
app_count = len(orphaned_by_app)
|
||||
|
||||
return {
|
||||
"total_orphaned_variables": total_orphaned,
|
||||
"total_orphaned_files": total_files,
|
||||
"orphaned_app_count": app_count,
|
||||
"orphaned_by_app": orphaned_by_app,
|
||||
}
|
||||
@ -1299,6 +1367,7 @@ def cleanup_orphaned_draft_variables(
|
||||
stats = _count_orphaned_draft_variables()
|
||||
|
||||
logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"])
|
||||
logger.info("Found %s associated offload files", stats["total_orphaned_files"])
|
||||
logger.info("Across %s non-existent apps", stats["orphaned_app_count"])
|
||||
|
||||
if stats["total_orphaned_variables"] == 0:
|
||||
@ -1307,10 +1376,10 @@ def cleanup_orphaned_draft_variables(
|
||||
|
||||
if dry_run:
|
||||
logger.info("DRY RUN: Would delete the following:")
|
||||
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
|
||||
for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[
|
||||
:10
|
||||
]: # Show top 10
|
||||
logger.info(" App %s: %s variables", app_id, count)
|
||||
logger.info(" App %s: %s variables, %s files", app_id, data["variables"], data["files"])
|
||||
if len(stats["orphaned_by_app"]) > 10:
|
||||
logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10)
|
||||
return
|
||||
@ -1319,7 +1388,8 @@ def cleanup_orphaned_draft_variables(
|
||||
if not force:
|
||||
click.confirm(
|
||||
f"Are you sure you want to delete {stats['total_orphaned_variables']} "
|
||||
f"orphaned draft variables from {stats['orphaned_app_count']} apps?",
|
||||
f"orphaned draft variables and {stats['total_orphaned_files']} associated files "
|
||||
f"from {stats['orphaned_app_count']} apps?",
|
||||
abort=True,
|
||||
)
|
||||
|
||||
@ -1352,3 +1422,480 @@ def cleanup_orphaned_draft_variables(
|
||||
continue
|
||||
|
||||
logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
|
||||
|
||||
|
||||
@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_datasource_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup datasource oauth client
|
||||
"""
|
||||
provider_id = DatasourceProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
|
||||
deleted_count = (
|
||||
db.session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
|
||||
oauth_client = DatasourceOauthParamConfig(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
system_credentials=client_params_dict,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"provider: {provider_name}", fg="green"))
|
||||
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
|
||||
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
|
||||
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
||||
@click.option(
|
||||
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
|
||||
)
|
||||
def transform_datasource_credentials(environment: str):
|
||||
"""
|
||||
Transform datasource credentials
|
||||
"""
|
||||
try:
|
||||
installer_manager = PluginInstaller()
|
||||
plugin_migration = PluginMigration()
|
||||
|
||||
notion_plugin_id = "langgenius/notion_datasource"
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
if environment == "online":
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
else:
|
||||
notion_plugin_unique_identifier = None
|
||||
firecrawl_plugin_unique_identifier = None
|
||||
jina_plugin_unique_identifier = None
|
||||
oauth_credential_type = CredentialType.OAUTH2
|
||||
api_key_credential_type = CredentialType.API_KEY
|
||||
|
||||
# deal notion credentials
|
||||
deal_notion_count = 0
|
||||
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
|
||||
if notion_credentials:
|
||||
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
|
||||
for notion_credential in notion_credentials:
|
||||
tenant_id = notion_credential.tenant_id
|
||||
if tenant_id not in notion_credentials_tenant_mapping:
|
||||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
|
||||
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for notion_tenant_credential in notion_tenant_credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = notion_tenant_credential.access_token
|
||||
# notion info
|
||||
notion_info = notion_tenant_credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
)
|
||||
continue
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
|
||||
if firecrawl_credentials:
|
||||
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for firecrawl_credential in firecrawl_credentials:
|
||||
tenant_id = firecrawl_credential.tenant_id
|
||||
if tenant_id not in firecrawl_credentials_tenant_mapping:
|
||||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
|
||||
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not firecrawl_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red"
|
||||
)
|
||||
)
|
||||
continue
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
|
||||
if jina_credentials:
|
||||
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for jina_credential in jina_credentials:
|
||||
tenant_id = jina_credential.tenant_id
|
||||
if tenant_id not in jina_credentials_tenant_mapping:
|
||||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
|
||||
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
|
||||
tenant = db.session.query(Tenant).filter_by(id=tenant_id).first()
|
||||
if not tenant:
|
||||
continue
|
||||
try:
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not jina_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jinareader",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red")
|
||||
)
|
||||
continue
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
|
||||
click.echo(
|
||||
click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
|
||||
)
|
||||
click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
|
||||
|
||||
|
||||
@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.")
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
@click.option(
|
||||
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
|
||||
)
|
||||
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
|
||||
def install_rag_pipeline_plugins(input_file, output_file, workers):
|
||||
"""
|
||||
Install rag pipeline plugins
|
||||
"""
|
||||
click.echo(click.style("Installing rag pipeline plugins", fg="yellow"))
|
||||
plugin_migration = PluginMigration()
|
||||
plugin_migration.install_rag_pipeline_plugins(
|
||||
input_file,
|
||||
output_file,
|
||||
workers,
|
||||
)
|
||||
click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))
|
||||
|
||||
|
||||
@click.command(
|
||||
"migrate-oss",
|
||||
help="Migrate files from Local or OpenDAL source to a cloud OSS storage (destination must NOT be local/opendal).",
|
||||
)
|
||||
@click.option(
|
||||
"--path",
|
||||
"paths",
|
||||
multiple=True,
|
||||
help="Storage path prefixes to migrate (repeatable). Defaults: privkeys, upload_files, image_files,"
|
||||
" tools, website_files, keyword_files, ops_trace",
|
||||
)
|
||||
@click.option(
|
||||
"--source",
|
||||
type=click.Choice(["local", "opendal"], case_sensitive=False),
|
||||
default="opendal",
|
||||
show_default=True,
|
||||
help="Source storage type to read from",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite destination if file already exists")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Show what would be migrated without uploading")
|
||||
@click.option("-f", "--force", is_flag=True, help="Skip confirmation and run without prompts")
|
||||
@click.option(
|
||||
"--update-db/--no-update-db",
|
||||
default=True,
|
||||
help="Update upload_files.storage_type from source type to current storage after migration",
|
||||
)
|
||||
def migrate_oss(
|
||||
paths: tuple[str, ...],
|
||||
source: str,
|
||||
overwrite: bool,
|
||||
dry_run: bool,
|
||||
force: bool,
|
||||
update_db: bool,
|
||||
):
|
||||
"""
|
||||
Copy all files under selected prefixes from a source storage
|
||||
(Local filesystem or OpenDAL-backed) into the currently configured
|
||||
destination storage backend, then optionally update DB records.
|
||||
|
||||
Expected usage: set STORAGE_TYPE (and its credentials) to your target backend.
|
||||
"""
|
||||
# Ensure target storage is not local/opendal
|
||||
if dify_config.STORAGE_TYPE in (StorageType.LOCAL, StorageType.OPENDAL):
|
||||
click.echo(
|
||||
click.style(
|
||||
"Target STORAGE_TYPE must be a cloud OSS (not 'local' or 'opendal').\n"
|
||||
"Please set STORAGE_TYPE to one of: s3, aliyun-oss, azure-blob, google-storage, tencent-cos, \n"
|
||||
"volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs, clickzetta-volume.",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Default paths if none specified
|
||||
default_paths = ("privkeys", "upload_files", "image_files", "tools", "website_files", "keyword_files", "ops_trace")
|
||||
path_list = list(paths) if paths else list(default_paths)
|
||||
is_source_local = source.lower() == "local"
|
||||
|
||||
click.echo(click.style("Preparing migration to target storage.", fg="yellow"))
|
||||
click.echo(click.style(f"Target storage type: {dify_config.STORAGE_TYPE}", fg="white"))
|
||||
if is_source_local:
|
||||
src_root = dify_config.STORAGE_LOCAL_PATH
|
||||
click.echo(click.style(f"Source: local fs, root: {src_root}", fg="white"))
|
||||
else:
|
||||
click.echo(click.style(f"Source: opendal scheme={dify_config.OPENDAL_SCHEME}", fg="white"))
|
||||
click.echo(click.style(f"Paths to migrate: {', '.join(path_list)}", fg="white"))
|
||||
click.echo("")
|
||||
|
||||
if not force:
|
||||
click.confirm("Proceed with migration?", abort=True)
|
||||
|
||||
# Instantiate source storage
|
||||
try:
|
||||
if is_source_local:
|
||||
src_root = dify_config.STORAGE_LOCAL_PATH
|
||||
source_storage = OpenDALStorage(scheme="fs", root=src_root)
|
||||
else:
|
||||
source_storage = OpenDALStorage(scheme=dify_config.OPENDAL_SCHEME)
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Failed to initialize source storage: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
total_files = 0
|
||||
copied_files = 0
|
||||
skipped_files = 0
|
||||
errored_files = 0
|
||||
copied_upload_file_keys: list[str] = []
|
||||
|
||||
for prefix in path_list:
|
||||
click.echo(click.style(f"Scanning source path: {prefix}", fg="white"))
|
||||
try:
|
||||
keys = source_storage.scan(path=prefix, files=True, directories=False)
|
||||
except FileNotFoundError:
|
||||
click.echo(click.style(f" -> Skipping missing path: {prefix}", fg="yellow"))
|
||||
continue
|
||||
except NotImplementedError:
|
||||
click.echo(click.style(" -> Source storage does not support scanning.", fg="red"))
|
||||
return
|
||||
except Exception as e:
|
||||
click.echo(click.style(f" -> Error scanning '{prefix}': {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
click.echo(click.style(f"Found {len(keys)} files under {prefix}", fg="white"))
|
||||
|
||||
for key in keys:
|
||||
total_files += 1
|
||||
|
||||
# check destination existence
|
||||
if not overwrite:
|
||||
try:
|
||||
if storage.exists(key):
|
||||
skipped_files += 1
|
||||
continue
|
||||
except Exception as e:
|
||||
# existence check failures should not block migration attempt
|
||||
# but should be surfaced to user as a warning for visibility
|
||||
click.echo(
|
||||
click.style(
|
||||
f" -> Warning: failed target existence check for {key}: {str(e)}",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
copied_files += 1
|
||||
continue
|
||||
|
||||
# read from source and write to destination
|
||||
try:
|
||||
data = source_storage.load_once(key)
|
||||
except FileNotFoundError:
|
||||
errored_files += 1
|
||||
click.echo(click.style(f" -> Missing on source: {key}", fg="yellow"))
|
||||
continue
|
||||
except Exception as e:
|
||||
errored_files += 1
|
||||
click.echo(click.style(f" -> Error reading {key}: {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
try:
|
||||
storage.save(key, data)
|
||||
copied_files += 1
|
||||
if prefix == "upload_files":
|
||||
copied_upload_file_keys.append(key)
|
||||
except Exception as e:
|
||||
errored_files += 1
|
||||
click.echo(click.style(f" -> Error writing {key} to target: {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
click.echo("")
|
||||
click.echo(click.style("Migration summary:", fg="yellow"))
|
||||
click.echo(click.style(f" Total: {total_files}", fg="white"))
|
||||
click.echo(click.style(f" Copied: {copied_files}", fg="green"))
|
||||
click.echo(click.style(f" Skipped: {skipped_files}", fg="white"))
|
||||
if errored_files:
|
||||
click.echo(click.style(f" Errors: {errored_files}", fg="red"))
|
||||
|
||||
if dry_run:
|
||||
click.echo(click.style("Dry-run complete. No changes were made.", fg="green"))
|
||||
return
|
||||
|
||||
if errored_files:
|
||||
click.echo(
|
||||
click.style(
|
||||
"Some files failed to migrate. Review errors above before updating DB records.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
if update_db and not force:
|
||||
if not click.confirm("Proceed to update DB storage_type despite errors?", default=False):
|
||||
update_db = False
|
||||
|
||||
# Optionally update DB records for upload_files.storage_type (only for successfully copied upload_files)
|
||||
if update_db:
|
||||
if not copied_upload_file_keys:
|
||||
click.echo(click.style("No upload_files copied. Skipping DB storage_type update.", fg="yellow"))
|
||||
else:
|
||||
try:
|
||||
source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL
|
||||
updated = (
|
||||
db.session.query(UploadFile)
|
||||
.where(
|
||||
UploadFile.storage_type == source_storage_type,
|
||||
UploadFile.key.in_(copied_upload_file_keys),
|
||||
)
|
||||
.update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False)
|
||||
)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green"))
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from .app_config import DifyConfig
|
||||
|
||||
dify_config = DifyConfig()
|
||||
dify_config = DifyConfig() # type: ignore
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import (
|
||||
@ -112,6 +113,21 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_POOL_MAX_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of concurrent connections for the code execution HTTP client",
|
||||
default=100,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of persistent keep-alive connections for the code execution HTTP client",
|
||||
default=20,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
|
||||
description="Keep-alive expiry in seconds for idle connections (set to None to disable)",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
CODE_MAX_NUMBER: PositiveInt = Field(
|
||||
description="Maximum allowed numeric value in code execution",
|
||||
default=9223372036854775807,
|
||||
@ -134,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
|
||||
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
|
||||
description="Maximum allowed length for strings in code execution",
|
||||
default=80000,
|
||||
default=400_000,
|
||||
)
|
||||
|
||||
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
|
||||
@ -152,6 +168,38 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
default=1000,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_SSL_VERIFY: bool = Field(
|
||||
description="Enable or disable SSL verification for code execution requests",
|
||||
default=True,
|
||||
)
|
||||
|
||||
|
||||
class TriggerConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for trigger
|
||||
"""
|
||||
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for webhook request bodies in bytes",
|
||||
default=10485760,
|
||||
)
|
||||
|
||||
|
||||
class AsyncWorkflowConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for async workflow
|
||||
"""
|
||||
|
||||
ASYNC_WORKFLOW_SCHEDULER_GRANULARITY: int = Field(
|
||||
description="Granularity for async workflow scheduler, "
|
||||
"sometime, few users could block the queue due to some time-consuming tasks, "
|
||||
"to avoid this, workflow can be suspended if needed, to achieve"
|
||||
"this, a time-based checker is required, every granularity seconds, "
|
||||
"the checker will check the workflow queue and suspend the workflow",
|
||||
default=120,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
|
||||
class PluginConfig(BaseSettings):
|
||||
"""
|
||||
@ -168,6 +216,11 @@ class PluginConfig(BaseSettings):
|
||||
default="plugin-api-key",
|
||||
)
|
||||
|
||||
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||
default=300.0,
|
||||
)
|
||||
|
||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||
|
||||
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
||||
@ -237,6 +290,8 @@ class EndpointConfig(BaseSettings):
|
||||
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
|
||||
)
|
||||
|
||||
TRIGGER_URL: str = Field(description="Template url for triggers", default="http://localhost:5001")
|
||||
|
||||
|
||||
class FileAccessConfig(BaseSettings):
|
||||
"""
|
||||
@ -305,12 +360,42 @@ class FileUploadConfig(BaseSettings):
|
||||
default=10,
|
||||
)
|
||||
|
||||
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
||||
description=(
|
||||
"Comma-separated list of file extensions that are blocked from upload. "
|
||||
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
|
||||
"Empty by default to allow all file types."
|
||||
),
|
||||
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
|
||||
"""
|
||||
Parse and return the blacklist as a set of lowercase extensions.
|
||||
Returns an empty set if no blacklist is configured.
|
||||
"""
|
||||
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
|
||||
return set()
|
||||
return {
|
||||
ext.strip().lower().strip(".")
|
||||
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
|
||||
if ext.strip()
|
||||
}
|
||||
|
||||
|
||||
class HttpConfig(BaseSettings):
|
||||
"""
|
||||
HTTP-related configurations for the application
|
||||
"""
|
||||
|
||||
COOKIE_DOMAIN: str = Field(
|
||||
description="Explicit cookie domain for console/service cookies when sharing across subdomains",
|
||||
default="",
|
||||
)
|
||||
|
||||
API_COMPRESSION_ENABLED: bool = Field(
|
||||
description="Enable or disable gzip compression for HTTP responses",
|
||||
default=False,
|
||||
@ -341,11 +426,11 @@ class HttpConfig(BaseSettings):
|
||||
)
|
||||
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
|
||||
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
|
||||
)
|
||||
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
|
||||
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
|
||||
)
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
||||
@ -403,6 +488,21 @@ class HttpConfig(BaseSettings):
|
||||
default=5,
|
||||
)
|
||||
|
||||
SSRF_POOL_MAX_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of concurrent connections for the SSRF HTTP client",
|
||||
default=100,
|
||||
)
|
||||
|
||||
SSRF_POOL_MAX_KEEPALIVE_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of persistent keep-alive connections for the SSRF HTTP client",
|
||||
default=20,
|
||||
)
|
||||
|
||||
SSRF_POOL_KEEPALIVE_EXPIRY: PositiveFloat | None = Field(
|
||||
description="Keep-alive expiry in seconds for idle SSRF connections (set to None to disable)",
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED: bool = Field(
|
||||
description="Enable handling of X-Forwarded-For, X-Forwarded-Proto, and X-Forwarded-Port headers"
|
||||
" when the app is behind a single trusted reverse proxy.",
|
||||
@ -505,6 +605,22 @@ class UpdateConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class WorkflowVariableTruncationConfig(BaseSettings):
|
||||
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
|
||||
# 1000 KiB
|
||||
1024_000,
|
||||
description="Maximum size for variable to trigger final truncation.",
|
||||
)
|
||||
WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH: PositiveInt = Field(
|
||||
100000,
|
||||
description="maximum length for string to trigger tuncation, measure in number of characters",
|
||||
)
|
||||
WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH: PositiveInt = Field(
|
||||
1000,
|
||||
description="maximum length for array to trigger truncation.",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for workflow execution
|
||||
@ -525,16 +641,38 @@ class WorkflowConfig(BaseSettings):
|
||||
default=5,
|
||||
)
|
||||
|
||||
WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
|
||||
description="Maximum allowed depth for nested parallel executions",
|
||||
default=3,
|
||||
)
|
||||
|
||||
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
||||
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
|
||||
default=200 * 1024,
|
||||
)
|
||||
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
|
||||
description="Maximum number of characters allowed in Template Transform node output",
|
||||
default=400_000,
|
||||
)
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
description="Maximum number of workers per GraphEngine instance",
|
||||
default=10,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
|
||||
description="Queue depth threshold that triggers worker scale up",
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
|
||||
description="Seconds of idle time before scaling down workers",
|
||||
default=5.0,
|
||||
ge=0.1,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowNodeExecutionConfig(BaseSettings):
|
||||
"""
|
||||
@ -673,11 +811,35 @@ class ToolConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TemplateMode(StrEnum):
|
||||
# unsafe mode allows flexible operations in templates, but may cause security vulnerabilities
|
||||
UNSAFE = "unsafe"
|
||||
|
||||
# sandbox mode restricts some unsafe operations like accessing __class__.
|
||||
# however, it is still not 100% safe, for example, cpu exploitation can happen.
|
||||
SANDBOX = "sandbox"
|
||||
|
||||
# templating is disabled
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class MailConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for email services
|
||||
"""
|
||||
|
||||
MAIL_TEMPLATING_MODE: TemplateMode = Field(
|
||||
description="Template mode for email services",
|
||||
default=TemplateMode.SANDBOX,
|
||||
)
|
||||
|
||||
MAIL_TEMPLATING_TIMEOUT: int = Field(
|
||||
description="""
|
||||
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
|
||||
Only available in sandbox mode.""",
|
||||
default=3,
|
||||
)
|
||||
|
||||
MAIL_TYPE: str | None = Field(
|
||||
description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.",
|
||||
default=None,
|
||||
@ -812,6 +974,11 @@ class DataSetConfig(BaseSettings):
|
||||
default=True,
|
||||
)
|
||||
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
|
||||
description="Maximum number of segments for dataset segments API (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
@ -887,6 +1054,44 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable check upgradable plugin task",
|
||||
default=True,
|
||||
)
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
|
||||
description="Enable workflow schedule poller task",
|
||||
default=True,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
|
||||
description="Workflow schedule poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
|
||||
description="Maximum number of schedules to process in each poll batch",
|
||||
default=100,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
|
||||
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
# Trigger provider refresh (simple version)
|
||||
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
|
||||
description="Enable trigger provider refresh poller",
|
||||
default=True,
|
||||
)
|
||||
TRIGGER_PROVIDER_REFRESH_INTERVAL: int = Field(
|
||||
description="Trigger provider refresh poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
TRIGGER_PROVIDER_REFRESH_BATCH_SIZE: int = Field(
|
||||
description="Max trigger subscriptions to process per tick",
|
||||
default=200,
|
||||
)
|
||||
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
|
||||
description="Proactive credential refresh threshold in seconds",
|
||||
default=180,
|
||||
)
|
||||
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
|
||||
description="Proactive subscription refresh threshold in seconds",
|
||||
default=60 * 60,
|
||||
)
|
||||
|
||||
|
||||
class PositionConfig(BaseSettings):
|
||||
@ -985,7 +1190,7 @@ class AccountConfig(BaseSettings):
|
||||
|
||||
|
||||
class WorkflowLogConfig(BaseSettings):
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
|
||||
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
|
||||
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
|
||||
default=100, description="Batch size for workflow run log cleanup operations"
|
||||
@ -1004,12 +1209,21 @@ class SwaggerUIConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
|
||||
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
@ -1028,6 +1242,7 @@ class FeatureConfig(
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
@ -1041,5 +1256,6 @@ class FeatureConfig(
|
||||
CeleryBeatConfig,
|
||||
CeleryScheduleTasksConfig,
|
||||
WorkflowLogConfig,
|
||||
WorkflowVariableTruncationConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
@ -220,11 +220,28 @@ class HostedFetchAppTemplateConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class HostedFetchPipelineTemplateConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching pipeline templates
|
||||
"""
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
|
||||
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
|
||||
default="remote",
|
||||
)
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||
description="Domain for fetching remote pipeline templates",
|
||||
default="https://tmpl.dify.ai",
|
||||
)
|
||||
|
||||
|
||||
class HostedServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
HostedAnthropicConfig,
|
||||
HostedAzureOpenAiConfig,
|
||||
HostedFetchAppTemplateConfig,
|
||||
HostedFetchPipelineTemplateConfig,
|
||||
HostedMinmaxConfig,
|
||||
HostedOpenAiConfig,
|
||||
HostedSparkConfig,
|
||||
|
||||
@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
|
||||
from .storage.supabase_storage_config import SupabaseStorageConfig
|
||||
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
|
||||
from .vdb.analyticdb_config import AnalyticdbConfig
|
||||
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||
from .vdb.chroma_config import ChromaConfig
|
||||
@ -144,7 +145,7 @@ class DatabaseConfig(BaseSettings):
|
||||
default="postgresql",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||
db_extras = (
|
||||
@ -187,12 +188,17 @@ class DatabaseConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
SQLALCHEMY_POOL_TIMEOUT: NonNegativeInt = Field(
|
||||
description="Number of seconds to wait for a connection from the pool before raising a timeout error.",
|
||||
default=30,
|
||||
)
|
||||
|
||||
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
|
||||
description="Number of processes for the retrieval service, default to CPU cores.",
|
||||
default=os.cpu_count() or 1,
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
||||
# Parse DB_EXTRAS for 'options'
|
||||
@ -216,6 +222,7 @@ class DatabaseConfig(BaseSettings):
|
||||
"connect_args": connect_args,
|
||||
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
|
||||
"pool_reset_on_return": None,
|
||||
"pool_timeout": self.SQLALCHEMY_POOL_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
@ -324,6 +331,7 @@ class MiddlewareConfig(
|
||||
ClickzettaConfig,
|
||||
HuaweiCloudConfig,
|
||||
MilvusConfig,
|
||||
AlibabaCloudMySQLConfig,
|
||||
MyScaleConfig,
|
||||
OpenSearchConfig,
|
||||
OracleConfig,
|
||||
|
||||
54
api/configs/middleware/vdb/alibabacloud_mysql_config.py
Normal file
54
api/configs/middleware/vdb/alibabacloud_mysql_config.py
Normal file
@ -0,0 +1,54 @@
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AlibabaCloudMySQLConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for AlibabaCloud MySQL vector database
|
||||
"""
|
||||
|
||||
ALIBABACLOUD_MYSQL_HOST: str = Field(
|
||||
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
|
||||
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
|
||||
default=3306,
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_USER: str = Field(
|
||||
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
|
||||
default="root",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
|
||||
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
|
||||
default="",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
|
||||
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
|
||||
description="Maximum number of connections in the connection pool",
|
||||
default=5,
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
|
||||
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
|
||||
default="utf8mb4",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
|
||||
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
|
||||
"(e.g., 'cosine', 'euclidean')",
|
||||
default="cosine",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
|
||||
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
|
||||
default=6,
|
||||
)
|
||||
@ -41,3 +41,13 @@ class BaiduVectorDBConfig(BaseSettings):
|
||||
description="Number of replicas for the Baidu Vector Database (default is 3)",
|
||||
default=3,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_ANALYZER: str = Field(
|
||||
description="Analyzer type for inverted index in Baidu Vector Database (default is DEFAULT_ANALYZER)",
|
||||
default="DEFAULT_ANALYZER",
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_INVERTED_INDEX_PARSER_MODE: str = Field(
|
||||
description="Parser mode for inverted index in Baidu Vector Database (default is COARSE_MODE)",
|
||||
default="COARSE_MODE",
|
||||
)
|
||||
|
||||
@ -19,15 +19,15 @@ class LindormConfig(BaseSettings):
|
||||
description="Lindorm password",
|
||||
default=None,
|
||||
)
|
||||
DEFAULT_INDEX_TYPE: str | None = Field(
|
||||
LINDORM_INDEX_TYPE: str | None = Field(
|
||||
description="Lindorm Vector Index Type, hnsw or flat is available in dify",
|
||||
default="hnsw",
|
||||
)
|
||||
DEFAULT_DISTANCE_TYPE: str | None = Field(
|
||||
LINDORM_DISTANCE_TYPE: str | None = Field(
|
||||
description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
|
||||
)
|
||||
USING_UGC_INDEX: bool | None = Field(
|
||||
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
|
||||
default=False,
|
||||
LINDORM_USING_UGC: bool | None = Field(
|
||||
description="Using UGC index will store indexes with the same IndexType/Dimension in a single big index.",
|
||||
default=True,
|
||||
)
|
||||
LINDORM_QUERY_TIMEOUT: float | None = Field(description="The lindorm search request timeout (s)", default=2.0)
|
||||
|
||||
@ -37,3 +37,15 @@ class OceanBaseVectorConfig(BaseSettings):
|
||||
"with older versions",
|
||||
default=False,
|
||||
)
|
||||
|
||||
OCEANBASE_FULLTEXT_PARSER: str | None = Field(
|
||||
description=(
|
||||
"Fulltext parser to use for text indexing. "
|
||||
"Built-in options: 'ngram' (N-gram tokenizer for English/numbers), "
|
||||
"'beng' (Basic English tokenizer), 'space' (Space-based tokenizer), "
|
||||
"'ngram2' (Improved N-gram tokenizer), 'ik' (Chinese tokenizer). "
|
||||
"External plugins (require installation): 'japanese_ftparser' (Japanese tokenizer), "
|
||||
"'thai_ftparser' (Thai tokenizer). Default is 'ik'"
|
||||
),
|
||||
default="ik",
|
||||
)
|
||||
|
||||
@ -1,23 +1,24 @@
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AuthMethod(StrEnum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||
|
||||
|
||||
class OpenSearchConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for OpenSearch
|
||||
"""
|
||||
|
||||
class AuthMethod(Enum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||
|
||||
OPENSEARCH_HOST: str | None = Field(
|
||||
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
|
||||
default=None,
|
||||
|
||||
@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
|
||||
default=True,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
|
||||
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Number of objects to be processed in a single batch operation (default is 100)",
|
||||
default=100,
|
||||
|
||||
@ -29,7 +29,7 @@ def no_key_cache_key(namespace: str, key: str) -> str:
|
||||
|
||||
|
||||
# Returns whether the obtained value is obtained, and None if it does not
|
||||
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
|
||||
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any:
|
||||
if namespace_cache:
|
||||
kv_data = namespace_cache.get(CONFIGURATIONS)
|
||||
if kv_data is None:
|
||||
|
||||
@ -5,7 +5,7 @@ import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -30,10 +30,10 @@ class NacosHttpClient:
|
||||
params = {}
|
||||
try:
|
||||
self._inject_auth_info(headers, params)
|
||||
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except requests.RequestException as e:
|
||||
except httpx.RequestError as e:
|
||||
return f"Request to Nacos failed: {e}"
|
||||
|
||||
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
|
||||
@ -78,7 +78,7 @@ class NacosHttpClient:
|
||||
params = {"username": self.username, "password": self.password}
|
||||
url = "http://" + self.server + "/nacos/v1/auth/login"
|
||||
try:
|
||||
resp = requests.request("POST", url, headers=None, params=params)
|
||||
resp = httpx.request("POST", url, headers=None, params=params)
|
||||
resp.raise_for_status()
|
||||
response_data = resp.json()
|
||||
self.token = response_data.get("accessToken")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from configs import dify_config
|
||||
from libs.collection_utils import convert_to_lower_and_upper_set
|
||||
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||
UNKNOWN_VALUE = "[__UNKNOWN__]"
|
||||
@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
DEFAULT_FILE_NUMBER_LIMITS = 3
|
||||
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
|
||||
|
||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
|
||||
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
|
||||
|
||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
|
||||
|
||||
|
||||
_doc_extensions: list[str]
|
||||
_doc_extensions: set[str]
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
||||
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
"mdx",
|
||||
"pdf",
|
||||
"html",
|
||||
"htm",
|
||||
"xlsx",
|
||||
"xls",
|
||||
"vtt",
|
||||
"properties",
|
||||
"doc",
|
||||
"docx",
|
||||
"csv",
|
||||
"eml",
|
||||
"msg",
|
||||
"pptx",
|
||||
"xml",
|
||||
"epub",
|
||||
}
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
_doc_extensions.append("ppt")
|
||||
_doc_extensions.add("ppt")
|
||||
else:
|
||||
_doc_extensions = [
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
@ -37,5 +53,18 @@ else:
|
||||
"csv",
|
||||
"vtt",
|
||||
"properties",
|
||||
]
|
||||
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
|
||||
}
|
||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||
|
||||
# console
|
||||
COOKIE_NAME_ACCESS_TOKEN = "access_token"
|
||||
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
|
||||
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
|
||||
|
||||
# webapp
|
||||
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
|
||||
COOKIE_NAME_PASSPORT = "passport"
|
||||
|
||||
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
|
||||
HEADER_NAME_APP_CODE = "X-App-Code"
|
||||
HEADER_NAME_PASSPORT = "X-App-Passport"
|
||||
|
||||
@ -31,3 +31,9 @@ def supported_language(lang):
|
||||
|
||||
error = f"{lang} is not a valid language."
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def get_valid_language(lang: str | None) -> str:
|
||||
if lang and lang in languages:
|
||||
return lang
|
||||
return languages[0]
|
||||
|
||||
7343
api/constants/pipeline_templates.json
Normal file
7343
api/constants/pipeline_templates.json
Normal file
File diff suppressed because one or more lines are too long
@ -5,9 +5,11 @@ from typing import TYPE_CHECKING
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
|
||||
|
||||
"""
|
||||
@ -32,3 +34,19 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
|
||||
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_schemas")
|
||||
)
|
||||
|
||||
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
|
||||
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
|
||||
)
|
||||
|
||||
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("datasource_plugin_providers_lock")
|
||||
)
|
||||
|
||||
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers")
|
||||
)
|
||||
|
||||
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers_lock")
|
||||
)
|
||||
|
||||
@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
|
||||
code = 415
|
||||
|
||||
|
||||
class BlockedFileExtensionError(BaseHTTPException):
|
||||
error_code = "file_extension_blocked"
|
||||
description = "The file extension is blocked for security reasons."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
description = "Only one file is allowed."
|
||||
|
||||
@ -24,7 +24,7 @@ except ImportError:
|
||||
)
|
||||
else:
|
||||
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
|
||||
magic = None # type: ignore
|
||||
magic = None # type: ignore[assignment]
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -1,31 +1,10 @@
|
||||
from importlib import import_module
|
||||
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
from .app.app_import import AppImportApi, AppImportCheckDependenciesApi, AppImportConfirmApi
|
||||
from .explore.audio import ChatAudioApi, ChatTextApi
|
||||
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
||||
from .explore.conversation import (
|
||||
ConversationApi,
|
||||
ConversationListApi,
|
||||
ConversationPinApi,
|
||||
ConversationRenameApi,
|
||||
ConversationUnPinApi,
|
||||
)
|
||||
from .explore.message import (
|
||||
MessageFeedbackApi,
|
||||
MessageListApi,
|
||||
MessageMoreLikeThisApi,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from .explore.workflow import (
|
||||
InstalledAppWorkflowRunApi,
|
||||
InstalledAppWorkflowTaskStopApi,
|
||||
)
|
||||
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
|
||||
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
bp = Blueprint("console", __name__, url_prefix="/console/api")
|
||||
|
||||
api = ExternalApi(
|
||||
@ -35,23 +14,23 @@ api = ExternalApi(
|
||||
description="Console management APIs for app configuration, monitoring, and administration",
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
console_ns = Namespace("console", description="Console management API operations", path="/")
|
||||
|
||||
# File
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||
RESOURCE_MODULES = (
|
||||
"controllers.console.app.app_import",
|
||||
"controllers.console.explore.audio",
|
||||
"controllers.console.explore.completion",
|
||||
"controllers.console.explore.conversation",
|
||||
"controllers.console.explore.message",
|
||||
"controllers.console.explore.workflow",
|
||||
"controllers.console.files",
|
||||
"controllers.console.remote_files",
|
||||
)
|
||||
|
||||
# Remote files
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
|
||||
# Import App
|
||||
api.add_resource(AppImportApi, "/apps/imports")
|
||||
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
|
||||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||
for module_name in RESOURCE_MODULES:
|
||||
import_module(module_name)
|
||||
|
||||
# Ensure resource modules are imported so route decorators are evaluated.
|
||||
# Import other controllers
|
||||
from . import (
|
||||
admin,
|
||||
@ -61,6 +40,7 @@ from . import (
|
||||
init_validate,
|
||||
ping,
|
||||
setup,
|
||||
spec,
|
||||
version,
|
||||
)
|
||||
|
||||
@ -86,6 +66,7 @@ from .app import (
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
@ -114,6 +95,15 @@ from .datasets import (
|
||||
metadata,
|
||||
website,
|
||||
)
|
||||
from .datasets.rag_pipeline import (
|
||||
datasource_auth,
|
||||
datasource_content_preview,
|
||||
rag_pipeline,
|
||||
rag_pipeline_datasets,
|
||||
rag_pipeline_draft_variable,
|
||||
rag_pipeline_import,
|
||||
rag_pipeline_workflow,
|
||||
)
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
@ -137,80 +127,10 @@ from .workspace import (
|
||||
models,
|
||||
plugin,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
)
|
||||
|
||||
# Explore Audio
|
||||
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
|
||||
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
|
||||
|
||||
# Explore Completion
|
||||
api.add_resource(
|
||||
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
|
||||
)
|
||||
api.add_resource(
|
||||
CompletionStopApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_completion",
|
||||
)
|
||||
api.add_resource(
|
||||
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
|
||||
)
|
||||
api.add_resource(
|
||||
ChatStopApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
|
||||
endpoint="installed_app_stop_chat_completion",
|
||||
)
|
||||
|
||||
# Explore Conversation
|
||||
api.add_resource(
|
||||
ConversationRenameApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
|
||||
endpoint="installed_app_conversation_rename",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
|
||||
endpoint="installed_app_conversation",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationPinApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
|
||||
endpoint="installed_app_conversation_pin",
|
||||
)
|
||||
api.add_resource(
|
||||
ConversationUnPinApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
|
||||
endpoint="installed_app_conversation_unpin",
|
||||
)
|
||||
|
||||
|
||||
# Explore Message
|
||||
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
|
||||
api.add_resource(
|
||||
MessageFeedbackApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
|
||||
endpoint="installed_app_message_feedback",
|
||||
)
|
||||
api.add_resource(
|
||||
MessageMoreLikeThisApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
|
||||
endpoint="installed_app_more_like_this",
|
||||
)
|
||||
api.add_resource(
|
||||
MessageSuggestedQuestionApi,
|
||||
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
|
||||
endpoint="installed_app_suggested_question",
|
||||
)
|
||||
# Explore Workflow
|
||||
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
api.add_resource(
|
||||
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
|
||||
)
|
||||
|
||||
api.add_namespace(console_ns)
|
||||
|
||||
__all__ = [
|
||||
@ -238,6 +158,8 @@ __all__ = [
|
||||
"datasets",
|
||||
"datasets_document",
|
||||
"datasets_segments",
|
||||
"datasource_auth",
|
||||
"datasource_content_preview",
|
||||
"email_register",
|
||||
"endpoint",
|
||||
"extension",
|
||||
@ -263,13 +185,20 @@ __all__ = [
|
||||
"parameter",
|
||||
"ping",
|
||||
"plugin",
|
||||
"rag_pipeline",
|
||||
"rag_pipeline_datasets",
|
||||
"rag_pipeline_draft_variable",
|
||||
"rag_pipeline_import",
|
||||
"rag_pipeline_workflow",
|
||||
"recommended_app",
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"spec",
|
||||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trigger_providers",
|
||||
"version",
|
||||
"website",
|
||||
"workflow",
|
||||
@ -277,5 +206,6 @@ __all__ = [
|
||||
"workflow_draft_variable",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
"workflow_trigger",
|
||||
"workspace",
|
||||
]
|
||||
|
||||
@ -15,6 +15,7 @@ from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None:
|
||||
auth_token = extract_access_token(request)
|
||||
if not auth_token:
|
||||
raise Unauthorized("Authorization header is missing.")
|
||||
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
if auth_token != dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
@ -70,15 +61,17 @@ class InsertExploreAppListApi(Resource):
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("desc", type=str, location="json")
|
||||
parser.add_argument("copyright", type=str, location="json")
|
||||
parser.add_argument("privacy_policy", type=str, location="json")
|
||||
parser.add_argument("custom_disclaimer", type=str, location="json")
|
||||
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("desc", type=str, location="json")
|
||||
.add_argument("copyright", type=str, location="json")
|
||||
.add_argument("privacy_policy", type=str, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, location="json")
|
||||
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import flask_restx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
@ -8,12 +7,12 @@ from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
from . import api, console_ns
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
api_key_fields = {
|
||||
"id": fields.String,
|
||||
@ -57,7 +56,9 @@ class BaseApiKeyListResource(Resource):
|
||||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
@ -66,13 +67,12 @@ class BaseApiKeyListResource(Resource):
|
||||
return {"items": keys}
|
||||
|
||||
@marshal_with(api_key_fields)
|
||||
@edit_permission_required
|
||||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||
@ -89,7 +89,7 @@ class BaseApiKeyListResource(Resource):
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
api_token.tenant_id = current_tenant_id
|
||||
api_token.token = key
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
@ -108,7 +108,8 @@ class BaseApiKeyResource(Resource):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
@ -152,11 +153,6 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "app"
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
@ -173,11 +169,6 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "app"
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
@ -202,11 +193,6 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
@ -223,11 +209,6 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
|
||||
@ -5,18 +5,20 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/app/prompt-templates")
|
||||
class AdvancedPromptTemplateList(Resource):
|
||||
@api.doc("get_advanced_prompt_templates")
|
||||
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||
)
|
||||
@ -25,11 +27,6 @@ class AdvancedPromptTemplateList(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("app_mode", type=str, required=True, location="args")
|
||||
parser.add_argument("model_mode", type=str, required=True, location="args")
|
||||
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
|
||||
parser.add_argument("model_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
@ -8,17 +8,19 @@ from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
|
||||
class AgentLogApi(Resource):
|
||||
@api.doc("get_agent_logs")
|
||||
@api.doc(description="Get agent execution logs for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@ -27,10 +29,6 @@ class AgentLogApi(Resource):
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
|
||||
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||
|
||||
@ -1,15 +1,14 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -17,6 +16,7 @@ from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
@ -42,15 +42,15 @@ class AnnotationReplyActionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||
@ -69,10 +69,8 @@ class AppAnnotationSettingDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
||||
return result, 200
|
||||
@ -98,15 +96,12 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id, annotation_setting_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||
@ -124,10 +119,8 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id, job_id, action):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
@ -159,10 +152,8 @@ class AnnotationApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
@ -185,8 +176,10 @@ class AnnotationApi(Resource):
|
||||
api.model(
|
||||
"CreateAnnotationRequest",
|
||||
{
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"message_id": fields.String(description="Message ID (optional)"),
|
||||
"question": fields.String(description="Question text (required when message_id not provided)"),
|
||||
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
||||
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
||||
},
|
||||
)
|
||||
@ -198,25 +191,26 @@ class AnnotationApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
.add_argument("question", required=False, type=str, location="json")
|
||||
.add_argument("answer", required=False, type=str, location="json")
|
||||
.add_argument("content", required=False, type=str, location="json")
|
||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
||||
# Use request.args.getlist to get annotation_ids array directly
|
||||
@ -249,16 +243,21 @@ class AnnotationExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response = {"data": marshal(annotation_list, annotation_fields)}
|
||||
return response, 200
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@api.doc("update_delete_annotation")
|
||||
@ -267,20 +266,16 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@api.response(200, "Annotation updated successfully", annotation_fields)
|
||||
@api.response(204, "Annotation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
return annotation
|
||||
@ -288,10 +283,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||
@ -310,10 +303,8 @@ class AnnotationBatchImportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
@ -341,10 +332,8 @@ class AnnotationBatchImportStatusApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id, job_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
@ -376,10 +365,8 @@ class AnnotationHitHistoryListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
app_id = str(app_id)
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -12,14 +10,17 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import App, Workflow
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -28,12 +29,6 @@ from services.feature_service import FeatureService
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@api.doc("list_apps")
|
||||
@ -61,6 +56,7 @@ class AppListApi(Resource):
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
def uuid_list(value):
|
||||
try:
|
||||
@ -68,34 +64,36 @@ class AppListApi(Resource):
|
||||
except ValueError:
|
||||
abort(400, message="Invalid UUID format in tag_ids.")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
choices=[
|
||||
"completion",
|
||||
"chat",
|
||||
"advanced-chat",
|
||||
"workflow",
|
||||
"agent-chat",
|
||||
"channel",
|
||||
"all",
|
||||
],
|
||||
default="all",
|
||||
location="args",
|
||||
required=False,
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
choices=[
|
||||
"completion",
|
||||
"chat",
|
||||
"advanced-chat",
|
||||
"workflow",
|
||||
"agent-chat",
|
||||
"channel",
|
||||
"all",
|
||||
],
|
||||
default="all",
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument("name", type=str, location="args", required=False)
|
||||
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||
)
|
||||
parser.add_argument("name", type=str, location="args", required=False)
|
||||
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
|
||||
@ -109,6 +107,35 @@ class AppListApi(Resource):
|
||||
if str(app.id) in res:
|
||||
app.access_mode = res[str(app.id)].access_mode
|
||||
|
||||
workflow_capable_app_ids = [
|
||||
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
|
||||
]
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
select(Workflow).where(
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
Workflow.app_id.in_(workflow_capable_app_ids),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
trigger_node_types = {
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
return marshal(app_pagination, app_pagination_fields), 200
|
||||
|
||||
@api.doc("create_app")
|
||||
@ -134,30 +161,26 @@ class AppListApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_fields)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if "mode" not in args or args["mode"] is None:
|
||||
raise BadRequest("mode is required")
|
||||
|
||||
app_service = AppService()
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
if current_user.current_tenant_id is None:
|
||||
raise ValueError("current_user.current_tenant_id cannot be None")
|
||||
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
|
||||
app = app_service.create_app(current_tenant_id, args, current_user)
|
||||
|
||||
return app, 201
|
||||
|
||||
@ -210,21 +233,20 @@ class AppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def put(self, app_model):
|
||||
"""Update app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
parser.add_argument("max_active_requests", type=int, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
.add_argument("max_active_requests", type=int, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
@ -253,12 +275,9 @@ class AppApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model):
|
||||
"""Delete app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
|
||||
@ -288,28 +307,29 @@ class AppCopyApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def post(self, app_model):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=validate_description_length, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
account=current_user,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=yaml_content,
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
@ -345,16 +365,15 @@ class AppExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
parser.add_argument("workflow_id", type=str, location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
.add_argument("workflow_id", type=str, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
return {
|
||||
@ -364,25 +383,23 @@ class AppExportApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@api.doc("check_app_name")
|
||||
@api.doc(description="Check if app name is available")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
|
||||
@api.expect(parser)
|
||||
@api.response(200, "Name availability checked")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
@ -413,14 +430,13 @@ class AppIconApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
@ -446,13 +462,9 @@ class AppSiteStatus(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enable_site", type=bool, required=True, location="json")
|
||||
parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
@ -480,11 +492,11 @@ class AppApiStatus(Resource):
|
||||
@marshal_with(app_detail_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enable_api", type=bool, required=True, location="json")
|
||||
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
@ -525,13 +537,14 @@ class AppTraceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
# add app trace
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enabled", type=bool, required=True, location="json")
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("enabled", type=bool, required=True, location="json")
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
|
||||
@ -1,54 +1,57 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("app_id", type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@api.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_fields)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("mode", type=str, required=True, location="json")
|
||||
parser.add_argument("yaml_content", type=str, location="json")
|
||||
parser.add_argument("yaml_url", type=str, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("app_id", type=str, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = cast(Account, current_user)
|
||||
account = current_user
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
@ -67,47 +70,47 @@ class AppImportApi(Resource):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports/<string:import_id>/confirm")
|
||||
class AppImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_fields)
|
||||
@edit_permission_required
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = cast(Account, current_user)
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
|
||||
class AppImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_check_dependencies_fields)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
@ -111,11 +111,13 @@ class ChatMessageTextApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self, app_model: App):
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, location="json")
|
||||
.add_argument("text", type=str, location="json")
|
||||
.add_argument("voice", type=str, location="json")
|
||||
.add_argument("streaming", type=bool, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
@ -166,8 +168,7 @@ class TextModesApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("language", type=str, required=True, location="args")
|
||||
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
@ -15,7 +15,7 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -64,13 +64,15 @@ class CompletionMessageApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, location="json", default="")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
@ -151,22 +153,19 @@ class ChatMessageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
|
||||
@ -1,16 +1,14 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytz # pip install pytz
|
||||
from flask_login import current_user
|
||||
import sqlalchemy as sa
|
||||
from flask import abort
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
@ -19,10 +17,10 @@ from fields.conversation_fields import (
|
||||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models import Account, Conversation, EndUser, Message, MessageAnnotation
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -56,21 +54,27 @@ class CompletionConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_pagination_fields)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
location="args",
|
||||
)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
query = db.select(Conversation).where(
|
||||
query = sa.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
)
|
||||
|
||||
@ -83,25 +87,18 @@ class CompletionConversationApi(Resource):
|
||||
)
|
||||
|
||||
account = current_user
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
assert account.timezone is not None
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
|
||||
# FIXME, the type ignore in this file
|
||||
@ -136,9 +133,8 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_message_detail_fields)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
@ -153,14 +149,12 @@ class CompletionConversationDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@ -205,26 +199,32 @@ class ChatConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_with_summary_pagination_fields)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
||||
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
location="args",
|
||||
)
|
||||
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
||||
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -236,7 +236,7 @@ class ChatConversationApi(Resource):
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args["keyword"]:
|
||||
keyword_filter = f"%{args['keyword']}%"
|
||||
@ -259,29 +259,22 @@ class ChatConversationApi(Resource):
|
||||
)
|
||||
|
||||
account = current_user
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
assert account.timezone is not None
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||
@ -308,7 +301,7 @@ class ChatConversationApi(Resource):
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
match args["sort_by"]:
|
||||
case "created_at":
|
||||
@ -340,9 +333,8 @@ class ChatConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_detail_fields)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
@ -357,14 +349,12 @@ class ChatConversationDetailApi(Resource):
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@ -373,6 +363,7 @@ class ChatConversationDetailApi(Resource):
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
|
||||
@ -29,8 +29,7 @@ class ConversationVariablesApi(Resource):
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_fields)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", type=str, location="args")
|
||||
parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
stmt = (
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
@ -12,11 +11,15 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
@ -40,16 +43,18 @@ class RuleGenerateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=account.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=args["no_variable"],
|
||||
@ -90,17 +95,19 @@ class RuleCodeGenerateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=account.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["code_language"],
|
||||
@ -137,15 +144,17 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=account.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
)
|
||||
@ -186,28 +195,26 @@ class InstructionGenerateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("flow_id", type=str, required=True, default="", location="json")
|
||||
parser.add_argument("node_id", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("current", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("language", type=str, required=False, default="javascript", location="json")
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
code_template = (
|
||||
Python3CodeProvider.get_default_code()
|
||||
if args["language"] == "python"
|
||||
else (JavascriptCodeProvider.get_default_code())
|
||||
if args["language"] == "javascript"
|
||||
else ""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("flow_id", type=str, required=True, default="", location="json")
|
||||
.add_argument("node_id", type=str, required=False, default="", location="json")
|
||||
.add_argument("current", type=str, required=False, default="", location="json")
|
||||
.add_argument("language", type=str, required=False, default="javascript", location="json")
|
||||
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args["language"])), None
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
from models import App, db
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if not app:
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
@ -222,21 +229,21 @@ class InstructionGenerateApi(Resource):
|
||||
match node_type:
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=True,
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=True,
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["language"],
|
||||
@ -245,7 +252,7 @@ class InstructionGenerateApi(Resource):
|
||||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
@ -254,13 +261,14 @@ class InstructionGenerateApi(Resource):
|
||||
)
|
||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
node_id=args["node_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
workflow_service=WorkflowService(),
|
||||
)
|
||||
return {"error": "incompatible parameters"}, 400
|
||||
except ProviderTokenNotInitError as ex:
|
||||
@ -292,8 +300,7 @@ class InstructionGenerationTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
match args["type"]:
|
||||
case "prompt":
|
||||
|
||||
@ -1,16 +1,15 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
@ -25,9 +24,9 @@ class AppMCPServerController(Resource):
|
||||
@api.doc(description="Get MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, app_model):
|
||||
@ -48,17 +47,19 @@ class AppMCPServerController(Resource):
|
||||
)
|
||||
@api.response(201, "MCP server configuration created successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@login_required
|
||||
@setup_required
|
||||
@marshal_with(app_server_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
description = args.get("description")
|
||||
@ -71,7 +72,7 @@ class AppMCPServerController(Resource):
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
server_code=AppMCPServer.generate_server_code(16),
|
||||
)
|
||||
db.session.add(server)
|
||||
@ -95,19 +96,20 @@ class AppMCPServerController(Resource):
|
||||
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Server not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("id", type=str, required=True, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
.add_argument("status", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
@ -142,13 +144,13 @@ class AppMCPServerRefreshController(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id)
|
||||
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
|
||||
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not server:
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
@ -16,20 +16,18 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
||||
from fields.conversation_fields import message_detail_fields
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
@ -56,16 +54,19 @@ class ChatMessageListApi(Resource):
|
||||
)
|
||||
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
.add_argument("first_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
conversation = (
|
||||
@ -151,12 +152,13 @@ class MessageFeedbackApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model):
|
||||
if current_user is None:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = str(args["message_id"])
|
||||
@ -190,47 +192,6 @@ class MessageFeedbackApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
class MessageAnnotationApi(Resource):
|
||||
@api.doc("create_message_annotation")
|
||||
@api.doc(description="Create message annotation")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MessageAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Annotation created successfully", annotation_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@get_app_model
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_model):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
||||
|
||||
return annotation
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
|
||||
class MessageAnnotationCountApi(Resource):
|
||||
@api.doc("get_annotation_count")
|
||||
@ -267,6 +228,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
@ -301,12 +263,12 @@ class MessageApi(Resource):
|
||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@api.response(200, "Message retrieved successfully", message_detail_fields)
|
||||
@api.response(404, "Message not found")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(message_detail_fields)
|
||||
def get(self, app_model, message_id):
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id = str(message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
@ -2,7 +2,6 @@ import json
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@ -14,8 +13,8 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
@ -53,16 +52,14 @@ class ModelConfigResource(Resource):
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
"""Modify app model config"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
config=cast(dict, request.json),
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
@ -90,16 +87,16 @@ class ModelConfigResource(Resource):
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
agent_tool_entity = AgentToolEntity.model_validate(tool)
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
@ -124,7 +121,7 @@ class ModelConfigResource(Resource):
|
||||
# encrypt agent tool parameters if it's secret-input
|
||||
agent_mode = new_app_model_config.agent_mode_dict
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
agent_tool_entity = AgentToolEntity.model_validate(tool)
|
||||
|
||||
# get tool
|
||||
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
|
||||
@ -133,7 +130,7 @@ class ModelConfigResource(Resource):
|
||||
else:
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
@ -141,7 +138,7 @@ class ModelConfigResource(Resource):
|
||||
continue
|
||||
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
@ -172,6 +169,8 @@ class ModelConfigResource(Resource):
|
||||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app_model.updated_by = current_user.id
|
||||
app_model.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
||||
@ -30,8 +30,7 @@ class TraceAppConfigApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -63,9 +62,11 @@ class TraceAppConfigApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self, app_id):
|
||||
"""Create a new trace app configuration"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -99,9 +100,11 @@ class TraceAppConfigApi(Resource):
|
||||
@account_initialization_required
|
||||
def patch(self, app_id):
|
||||
"""Update an existing trace app configuration"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -129,8 +132,7 @@ class TraceAppConfigApi(Resource):
|
||||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
"""Delete an existing trace app configuration"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -9,30 +8,36 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import Account, Site
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("title", type=str, required=False, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=False, location="json")
|
||||
parser.add_argument("icon", type=str, required=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, location="json")
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||
parser.add_argument("customize_domain", type=str, required=False, location="json")
|
||||
parser.add_argument("copyright", type=str, required=False, location="json")
|
||||
parser.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||
parser.add_argument(
|
||||
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("title", type=str, required=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, location="json")
|
||||
.add_argument("icon", type=str, required=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||
.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||
.add_argument("customize_domain", type=str, required=False, location="json")
|
||||
.add_argument("copyright", type=str, required=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||
.add_argument(
|
||||
"customize_token_strategy",
|
||||
type=str,
|
||||
choices=["must", "allow", "not_allow"],
|
||||
required=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
)
|
||||
parser.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -76,9 +81,10 @@ class AppSite(Resource):
|
||||
@marshal_with(app_site_fields)
|
||||
def post(self, app_model):
|
||||
args = parse_app_site_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
@ -107,8 +113,6 @@ class AppSite(Resource):
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@ -131,6 +135,8 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@marshal_with(app_site_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
@ -140,8 +146,6 @@ class AppSiteAccessTokenReset(Resource):
|
||||
raise NotFound
|
||||
|
||||
site.code = Site.generate_code(16)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
@ -1,10 +1,7 @@
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
@ -12,8 +9,9 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode, Message
|
||||
|
||||
|
||||
@ -37,11 +35,13 @@ class DailyMessageStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
@ -50,29 +50,21 @@ class DailyMessageStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -88,16 +80,19 @@ WHERE
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@api.doc("get_daily_conversation_statistics")
|
||||
@api.doc(description="Get daily conversation statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
@ -108,15 +103,15 @@ class DailyConversationStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
stmt = (
|
||||
sa.select(
|
||||
@ -126,21 +121,13 @@ class DailyConversationStatistic(Resource):
|
||||
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
|
||||
)
|
||||
.select_from(Message)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
|
||||
)
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
if start_datetime_utc:
|
||||
stmt = stmt.where(Message.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
if end_datetime_utc:
|
||||
stmt = stmt.where(Message.created_at < end_datetime_utc)
|
||||
|
||||
stmt = stmt.group_by("date").order_by("date")
|
||||
@ -159,11 +146,7 @@ class DailyTerminalsStatistic(Resource):
|
||||
@api.doc("get_daily_terminals_statistics")
|
||||
@api.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
@ -174,11 +157,8 @@ class DailyTerminalsStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
@ -187,29 +167,21 @@ class DailyTerminalsStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -230,11 +202,7 @@ class DailyTokenCostStatistic(Resource):
|
||||
@api.doc("get_daily_token_cost_statistics")
|
||||
@api.doc(description="Get daily token cost statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
@ -245,11 +213,8 @@ class DailyTokenCostStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
@ -259,29 +224,21 @@ class DailyTokenCostStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -304,11 +261,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@api.doc("get_average_session_interaction_statistics")
|
||||
@api.doc(description="Get average session interaction statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
@ -319,11 +272,8 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
@ -340,29 +290,21 @@ FROM
|
||||
messages m
|
||||
ON c.id = m.conversation_id
|
||||
WHERE
|
||||
c.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND c.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND c.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -394,11 +336,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@api.doc("get_user_satisfaction_rate_statistics")
|
||||
@api.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
@ -409,11 +347,8 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
@ -426,29 +361,21 @@ LEFT JOIN
|
||||
message_feedbacks mf
|
||||
ON mf.message_id=m.id AND mf.rating='like'
|
||||
WHERE
|
||||
m.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND m.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND m.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -474,11 +401,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@api.doc("get_average_response_time_statistics")
|
||||
@api.doc(description="Get average response time statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
@ -489,11 +412,8 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
@ -502,29 +422,21 @@ class AverageResponseTimeStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -545,11 +457,7 @@ class TokensPerSecondStatistic(Resource):
|
||||
@api.doc("get_tokens_per_second_statistics")
|
||||
@api.doc(description="Get tokens per second statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.expect(parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
@ -560,11 +468,7 @@ class TokensPerSecondStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
account, _ = current_account_with_tenant()
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
@ -576,29 +480,21 @@ class TokensPerSecondStatistic(Resource):
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
|
||||
@ -9,26 +9,36 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.exc import PluginInvokeError
|
||||
from core.trigger.debug.event_selectors import (
|
||||
TriggerDebugEvent,
|
||||
TriggerDebugEventPoller,
|
||||
create_event_poller,
|
||||
select_trigger_debug_events,
|
||||
)
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.account import Account
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -37,6 +47,7 @@ from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LISTENING_RETRY_IN = 2000
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
@ -69,15 +80,11 @@ class DraftWorkflowApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_fields)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
@ -106,27 +113,38 @@ class DraftWorkflowApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Draft workflow synced successfully", workflow_fields)
|
||||
@api.response(
|
||||
200,
|
||||
"Draft workflow synced successfully",
|
||||
api.model(
|
||||
"SyncDraftWorkflowResponse",
|
||||
{
|
||||
"result": fields.String,
|
||||
"hash": fields.String,
|
||||
"updated_at": fields.String,
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid workflow configuration")
|
||||
@api.response(403, "Permission denied")
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Sync draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("hash", type=str, required=False, location="json")
|
||||
parser.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("features", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("hash", type=str, required=False, location="json")
|
||||
.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
@ -148,10 +166,6 @@ class DraftWorkflowApi(Resource):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
@ -205,24 +219,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json", default="")
|
||||
parser.add_argument("files", type=list, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, location="json")
|
||||
.add_argument("query", type=str, required=True, location="json", default="")
|
||||
.add_argument("files", type=list, location="json")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -270,18 +281,13 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -322,18 +328,13 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -374,19 +375,13 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -427,19 +422,13 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -479,20 +468,17 @@ class DraftWorkflowRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
@ -513,7 +499,7 @@ class DraftWorkflowRunApi(Resource):
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class WorkflowTaskStopApi(Resource):
|
||||
@api.doc("stop_workflow_task")
|
||||
@api.doc(description="Stop running workflow task")
|
||||
@ -525,18 +511,17 @@ class WorkflowTaskStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -562,21 +547,18 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("query", type=str, required=False, location="json", default="")
|
||||
parser.add_argument("files", type=list, location="json", default=[])
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("query", type=str, required=False, location="json", default="")
|
||||
.add_argument("files", type=list, location="json", default=[])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_inputs = args.get("inputs")
|
||||
@ -604,6 +586,13 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
parser_publish = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish")
|
||||
class PublishedWorkflowApi(Resource):
|
||||
@api.doc("get_published_workflow")
|
||||
@ -616,17 +605,11 @@ class PublishedWorkflowApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_fields)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get published workflow
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# fetch published workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
||||
@ -634,24 +617,19 @@ class PublishedWorkflowApi(Resource):
|
||||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
|
||||
@api.expect(parser_publish)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_publish.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
@ -669,8 +647,12 @@ class PublishedWorkflowApi(Resource):
|
||||
marked_comment=args.marked_comment or "",
|
||||
)
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
|
||||
# Update app_model within the same session to ensure atomicity
|
||||
app_model_in_session = session.get(App, app_model.id)
|
||||
if app_model_in_session:
|
||||
app_model_in_session.workflow_id = workflow.id
|
||||
app_model_in_session.updated_by = current_user.id
|
||||
app_model_in_session.updated_at = naive_utc_now()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
@ -682,7 +664,7 @@ class PublishedWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-block-configs")
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
class DefaultBlockConfigsApi(Resource):
|
||||
@api.doc("get_default_block_configs")
|
||||
@api.doc(description="Get default block configurations for workflow")
|
||||
@ -692,46 +674,37 @@ class DefaultBlockConfigsApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
# Get default block configs
|
||||
workflow_service = WorkflowService()
|
||||
return workflow_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-block-configs/<string:block_type>")
|
||||
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultBlockConfigApi(Resource):
|
||||
@api.doc("get_default_block_config")
|
||||
@api.doc(description="Get default block configuration by type")
|
||||
@api.doc(params={"app_id": "Application ID", "block_type": "Block type"})
|
||||
@api.response(200, "Default block configuration retrieved successfully")
|
||||
@api.response(404, "Block type not found")
|
||||
@api.expect(parser_block)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, block_type: str):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = parser_block.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
|
||||
@ -747,8 +720,18 @@ class DefaultBlockConfigApi(Resource):
|
||||
return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
parser_convert = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
|
||||
class ConvertToWorkflowApi(Resource):
|
||||
@api.expect(parser_convert)
|
||||
@api.doc("convert_to_workflow")
|
||||
@api.doc(description="Convert application to workflow mode")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@ -759,25 +742,17 @@ class ConvertToWorkflowApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Convert basic mode of chatbot app to workflow mode
|
||||
Convert expert mode of chatbot app to workflow mode
|
||||
Convert Completion App to Workflow App
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if request.data:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("icon", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_convert.parse_args()
|
||||
else:
|
||||
args = {}
|
||||
|
||||
@ -791,26 +766,18 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/config")
|
||||
class WorkflowConfigApi(Resource):
|
||||
"""Resource for workflow configuration."""
|
||||
|
||||
@api.doc("get_workflow_config")
|
||||
@api.doc(description="Get workflow configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Workflow configuration retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
return {
|
||||
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
|
||||
}
|
||||
parser_workflows = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published")
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@api.expect(parser_workflows)
|
||||
@api.doc("get_all_published_workflows")
|
||||
@api.doc(description="Get all published workflows for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@ -820,24 +787,16 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_pagination_fields)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get published workflows
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument("user_id", type=str, required=False, location="args")
|
||||
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
page = int(args.get("page", 1))
|
||||
limit = int(args.get("limit", 10))
|
||||
args = parser_workflows.parse_args()
|
||||
page = args["page"]
|
||||
limit = args["limit"]
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
|
||||
@ -865,7 +824,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<uuid:workflow_id>")
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
|
||||
class WorkflowByIdApi(Resource):
|
||||
@api.doc("update_workflow_by_id")
|
||||
@api.doc(description="Update workflow by ID")
|
||||
@ -887,19 +846,17 @@ class WorkflowByIdApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_fields)
|
||||
@edit_permission_required
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
@ -942,16 +899,11 @@ class WorkflowByIdApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
Delete workflow
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
@ -998,3 +950,234 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
if node_exec is None:
|
||||
raise NotFound("last run not found")
|
||||
return node_exec
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run")
|
||||
class DraftWorkflowTriggerRunApi(Resource):
|
||||
"""
|
||||
Full workflow debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
|
||||
"""
|
||||
|
||||
@api.doc("poll_draft_workflow_trigger_run")
|
||||
@api.doc(description="Poll for trigger events and execute full workflow when event arrives")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DraftWorkflowTriggerRunRequest",
|
||||
{
|
||||
"node_id": fields.String(required=True, description="Node ID"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trigger event received and workflow executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Poll for trigger events and execute full workflow when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
|
||||
args = parser.parse_args()
|
||||
node_id = args["node_id"]
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
poller: TriggerDebugEventPoller = create_event_poller(
|
||||
draft_workflow=draft_workflow,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
)
|
||||
event: TriggerDebugEvent | None = None
|
||||
try:
|
||||
event = poller.poll()
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
workflow_args = dict(event.workflow_args)
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
return helper.compact_generate_response(
|
||||
AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=workflow_args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
root_node_id=node_id,
|
||||
)
|
||||
)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except PluginInvokeError as e:
|
||||
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
|
||||
except Exception as e:
|
||||
logger.exception("Error polling trigger debug event")
|
||||
raise e
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run")
|
||||
class DraftWorkflowTriggerNodeApi(Resource):
|
||||
"""
|
||||
Single node debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run
|
||||
"""
|
||||
|
||||
@api.doc("poll_draft_workflow_trigger_node")
|
||||
@api.doc(description="Poll for trigger events and execute single node when event arrives")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.response(200, "Trigger event received and node executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Poll for trigger events and execute single node when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
node_config = draft_workflow.get_node_config_by_id(node_id=node_id)
|
||||
if not node_config:
|
||||
raise ValueError("Node data not found for node %s", node_id)
|
||||
node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
|
||||
event: TriggerDebugEvent | None = None
|
||||
# for schedule trigger, when run single node, just execute directly
|
||||
if node_type == NodeType.TRIGGER_SCHEDULE:
|
||||
event = TriggerDebugEvent(
|
||||
workflow_args={},
|
||||
node_id=node_id,
|
||||
)
|
||||
# for other trigger types, poll for the event
|
||||
else:
|
||||
try:
|
||||
poller: TriggerDebugEventPoller = create_event_poller(
|
||||
draft_workflow=draft_workflow,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=current_user.id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
)
|
||||
event = poller.poll()
|
||||
except PluginInvokeError as e:
|
||||
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
|
||||
except Exception as e:
|
||||
logger.exception("Error polling trigger debug event")
|
||||
raise e
|
||||
if not event:
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
|
||||
raw_files = event.workflow_args.get("files")
|
||||
files = _parse_file(draft_workflow, raw_files if isinstance(raw_files, list) else None)
|
||||
try:
|
||||
node_execution = workflow_service.run_draft_workflow_node(
|
||||
app_model=app_model,
|
||||
draft_workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=event.workflow_args.get("inputs") or {},
|
||||
account=current_user,
|
||||
query="",
|
||||
files=files,
|
||||
)
|
||||
return jsonable_encoder(node_execution)
|
||||
except Exception as e:
|
||||
logger.exception("Error running draft workflow trigger node")
|
||||
return jsonable_encoder(
|
||||
{"status": "error", "error": "An unexpected error occurred while running the node."}
|
||||
), 400
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run-all")
|
||||
class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
"""
|
||||
Full workflow debug - Polling API for trigger events
|
||||
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run-all
|
||||
"""
|
||||
|
||||
@api.doc("draft_workflow_trigger_run_all")
|
||||
@api.doc(description="Full workflow debug when the start node is a trigger")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DraftWorkflowTriggerRunAllRequest",
|
||||
{
|
||||
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Workflow executed successfully")
|
||||
@api.response(403, "Permission denied")
|
||||
@api.response(500, "Internal server error")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Full workflow debug when the start node is a trigger
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False)
|
||||
args = parser.parse_args()
|
||||
node_ids = args["node_ids"]
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
try:
|
||||
trigger_debug_event: TriggerDebugEvent | None = select_trigger_debug_events(
|
||||
draft_workflow=draft_workflow,
|
||||
app_model=app_model,
|
||||
user_id=current_user.id,
|
||||
node_ids=node_ids,
|
||||
)
|
||||
except PluginInvokeError as e:
|
||||
return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
|
||||
except Exception as e:
|
||||
logger.exception("Error polling trigger debug event")
|
||||
raise e
|
||||
if trigger_debug_event is None:
|
||||
return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
|
||||
|
||||
try:
|
||||
workflow_args = dict(trigger_debug_event.workflow_args)
|
||||
workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=workflow_args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
root_node_id=trigger_debug_event.node_id,
|
||||
)
|
||||
return helper.compact_generate_response(response)
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except Exception:
|
||||
logger.exception("Error running draft workflow trigger run-all")
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"status": "error",
|
||||
}
|
||||
), 400
|
||||
|
||||
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
@ -28,6 +28,7 @@ class WorkflowAppLogApi(Resource):
|
||||
"created_at__after": "Filter logs created after this timestamp",
|
||||
"created_by_end_user_session_id": "Filter by end user session ID",
|
||||
"created_by_account": "Filter by account",
|
||||
"detail": "Whether to return detailed logs",
|
||||
"page": "Page number (1-99999)",
|
||||
"limit": "Number of items per page (1-100)",
|
||||
}
|
||||
@ -42,33 +43,36 @@ class WorkflowAppLogApi(Resource):
|
||||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
)
|
||||
.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
.add_argument(
|
||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument("detail", type=bool, location="args", required=False, default=False)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
@ -90,6 +94,7 @@ class WorkflowAppLogApi(Resource):
|
||||
created_at_after=args.created_at__after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
detail=args.detail,
|
||||
created_by_end_user_session_id=args.created_by_end_user_session_id,
|
||||
created_by_account=args.created_by_account,
|
||||
)
|
||||
|
||||
@ -13,15 +13,16 @@ from controllers.console.app.error import (
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.file import helpers as file_helpers
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models.account import Account
|
||||
from models import Account, App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -56,16 +57,18 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
@ -74,6 +77,22 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
return value_type.exposed_type().value
|
||||
|
||||
|
||||
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
|
||||
"""Serialize full_content information for large variables."""
|
||||
if not variable.is_truncated():
|
||||
return None
|
||||
|
||||
variable_file = variable.variable_file
|
||||
assert variable_file is not None
|
||||
|
||||
return {
|
||||
"size_bytes": variable_file.size,
|
||||
"value_type": variable_file.value_type.exposed_type().value,
|
||||
"length": variable_file.length,
|
||||
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
|
||||
}
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
@ -83,11 +102,13 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"value_type": fields.String(attribute=_serialize_variable_type),
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
"is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=_serialize_var_value),
|
||||
full_content=fields.Raw(attribute=_serialize_full_content),
|
||||
)
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
@ -300,10 +321,11 @@ class VariableApi(Resource):
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
|
||||
@ -9,15 +8,85 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.workflow_run_fields import (
|
||||
advanced_chat_workflow_run_pagination_fields,
|
||||
workflow_run_count_fields,
|
||||
workflow_run_detail_fields,
|
||||
workflow_run_node_execution_list_fields,
|
||||
workflow_run_pagination_fields,
|
||||
)
|
||||
from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models import Account, App, AppMode, EndUser
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
# Workflow run status choices for filtering
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
|
||||
|
||||
def _parse_workflow_run_list_args():
|
||||
"""
|
||||
Parse common arguments for workflow run list endpoints.
|
||||
|
||||
Returns:
|
||||
Parsed arguments containing last_id, limit, status, and triggered_from filters
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _parse_workflow_run_count_args():
|
||||
"""
|
||||
Parse common arguments for workflow run count endpoints.
|
||||
|
||||
Returns:
|
||||
Parsed arguments containing status, time_range, and triggered_from filters
|
||||
"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"status",
|
||||
type=str,
|
||||
choices=WORKFLOW_RUN_STATUS_CHOICES,
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"time_range",
|
||||
type=time_duration,
|
||||
location="args",
|
||||
required=False,
|
||||
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
|
||||
)
|
||||
.add_argument(
|
||||
"triggered_from",
|
||||
type=str,
|
||||
choices=["debugging", "app-run"],
|
||||
location="args",
|
||||
required=False,
|
||||
help="Filter by trigger source: debugging or app-run",
|
||||
)
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@ -25,6 +94,8 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@api.doc(description="Get advanced chat workflow run list")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -35,13 +106,64 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
args = _parse_workflow_run_list_args()
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
|
||||
app_model=app_model, args=args, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
|
||||
class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@api.doc("get_advanced_chat_workflow_runs_count")
|
||||
@api.doc(description="Get advanced chat workflow runs count statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
)
|
||||
}
|
||||
)
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(workflow_run_count_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat workflow runs count statistics
|
||||
"""
|
||||
args = _parse_workflow_run_count_args()
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status=args.get("status"),
|
||||
time_range=args.get("time_range"),
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@ -52,6 +174,8 @@ class WorkflowRunListApi(Resource):
|
||||
@api.doc(description="Get workflow run list")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -62,13 +186,64 @@ class WorkflowRunListApi(Resource):
|
||||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
args = _parse_workflow_run_list_args()
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
|
||||
result = workflow_run_service.get_paginate_workflow_runs(
|
||||
app_model=app_model, args=args, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
|
||||
class WorkflowRunCountApi(Resource):
|
||||
@api.doc("get_workflow_runs_count")
|
||||
@api.doc(description="Get workflow runs count statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
|
||||
@api.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
)
|
||||
}
|
||||
)
|
||||
@api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
|
||||
@api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_count_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow runs count statistics
|
||||
"""
|
||||
args = _parse_workflow_run_count_args()
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args.get("triggered_from"))
|
||||
if args.get("triggered_from")
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status=args.get("status"),
|
||||
time_range=args.get("time_range"),
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ -1,24 +1,26 @@
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask import abort, jsonify
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||
class WorkflowDailyRunsStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_daily_runs_statistic")
|
||||
@api.doc(description="Get workflow daily runs statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@ -29,64 +31,41 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(id) AS runs
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "runs": i.runs})
|
||||
response_data = self._workflow_run_repo.get_daily_runs_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
||||
class WorkflowDailyTerminalsStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_daily_terminals_statistic")
|
||||
@api.doc(description="Get workflow daily terminals statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@ -97,64 +76,41 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
|
||||
class WorkflowDailyTokenCostStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_daily_token_cost_statistic")
|
||||
@api.doc(description="Get workflow daily token cost statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@ -165,69 +121,41 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
SUM(workflow_runs.total_tokens) AS token_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{
|
||||
"date": str(i.date),
|
||||
"token_count": i.token_count,
|
||||
}
|
||||
)
|
||||
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
|
||||
class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_average_app_interaction_statistic")
|
||||
@api.doc(description="Get workflow average app interaction statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@ -238,74 +166,29 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
AVG(sub.interactions) AS interactions,
|
||||
sub.date
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
c.created_by,
|
||||
COUNT(c.id) AS interactions
|
||||
FROM
|
||||
workflow_runs c
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND c.triggered_from = :triggered_from
|
||||
{{start}}
|
||||
{{end}}
|
||||
GROUP BY
|
||||
date, c.created_by
|
||||
) sub
|
||||
GROUP BY
|
||||
sub.date"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{start}}", "")
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{end}}", "")
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||
)
|
||||
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
145
api/controllers/console/app/workflow_trigger.py
Normal file
145
api/controllers/console/app/workflow_trigger.py
Normal file
@ -0,0 +1,145 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from libs.login import current_user, login_required
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import Account, App, AppMode
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebhookTriggerApi(Resource):
|
||||
"""Webhook Trigger API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(webhook_trigger_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Get webhook trigger for a node"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
|
||||
args = parser.parse_args()
|
||||
|
||||
node_id = str(args["node_id"])
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger)
|
||||
.where(
|
||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||
WorkflowWebhookTrigger.node_id == node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not webhook_trigger:
|
||||
raise NotFound("Webhook trigger not found for this node")
|
||||
|
||||
return webhook_trigger
|
||||
|
||||
|
||||
class AppTriggersApi(Resource):
|
||||
"""App Triggers list API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(triggers_list_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Add computed icon field for each trigger
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
for trigger in triggers:
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return {"data": triggers}
|
||||
|
||||
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(trigger_fields)
|
||||
def post(self, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
trigger_id = args["trigger_id"]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
AppTrigger.id == trigger_id,
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not trigger:
|
||||
raise NotFound("Trigger not found")
|
||||
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Add computed icon field
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
|
||||
|
||||
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
|
||||
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")
|
||||
@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
P1 = ParamSpec("P1")
|
||||
R1 = TypeVar("R1")
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
assert isinstance(current_user, Account)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
|
||||
@ -7,18 +7,14 @@ from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
||||
from models.account import AccountStatus
|
||||
from models import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
active_check_parser = reqparse.RequestParser()
|
||||
active_check_parser.add_argument(
|
||||
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
|
||||
)
|
||||
active_check_parser.add_argument(
|
||||
"email", type=email, required=False, nullable=True, location="args", help="Email address"
|
||||
)
|
||||
active_check_parser.add_argument(
|
||||
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
|
||||
active_check_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
||||
)
|
||||
|
||||
|
||||
@ -60,15 +56,15 @@ class ActivateCheckApi(Resource):
|
||||
return {"is_valid": False}
|
||||
|
||||
|
||||
active_parser = reqparse.RequestParser()
|
||||
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
active_parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
active_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
)
|
||||
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/activate")
|
||||
@ -103,7 +99,7 @@ class ActivateApi(Resource):
|
||||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@ -1,21 +1,22 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source")
|
||||
class ApiKeyAuthDataSource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
"sources": [
|
||||
@ -33,41 +34,44 @@ class ApiKeyAuthDataSource(Resource):
|
||||
return {"sources": []}
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source/binding")
|
||||
class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source/<uuid:binding_id>")
|
||||
class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
|
||||
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding")
|
||||
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>")
|
||||
|
||||
@ -1,14 +1,13 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
@ -45,6 +44,7 @@ class OAuthDataSource(Resource):
|
||||
@api.response(403, "Admin privileges required")
|
||||
def get(self, provider: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
|
||||
return {"error": "Invalid code"}, 400
|
||||
try:
|
||||
oauth_provider.get_access_token(code)
|
||||
except requests.HTTPError as e:
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
|
||||
return {"error": "Invalid provider"}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except requests.HTTPError as e:
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
|
||||
@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailCodeError,
|
||||
@ -19,20 +19,23 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
|
||||
|
||||
@console_ns.route("/email-register/send-email")
|
||||
class EmailRegisterSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
@ -52,15 +55,18 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register/validity")
|
||||
class EmailRegisterCheckApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
@ -92,15 +98,18 @@ class EmailRegisterCheckApi(Resource):
|
||||
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
|
||||
|
||||
|
||||
@console_ns.route("/email-register")
|
||||
class EmailRegisterResetApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate passwords match
|
||||
@ -148,8 +157,3 @@ class EmailRegisterResetApi(Resource):
|
||||
raise AccountInFreezeError()
|
||||
|
||||
return account
|
||||
|
||||
|
||||
api.add_resource(EmailRegisterSendEmailApi, "/email-register/send-email")
|
||||
api.add_resource(EmailRegisterCheckApi, "/email-register/validity")
|
||||
api.add_resource(EmailRegisterResetApi, "/email-register")
|
||||
|
||||
@ -20,7 +20,7 @@ from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -54,9 +54,11 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
@ -111,10 +113,12 @@ class ForgotPasswordCheckApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
@ -169,10 +173,12 @@ class ForgotPasswordResetApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate passwords match
|
||||
@ -221,8 +227,3 @@ class ForgotPasswordResetApi(Resource):
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
|
||||
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
|
||||
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
|
||||
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
|
||||
|
||||
@ -1,13 +1,11 @@
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import request
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from constants.languages import get_valid_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
@ -26,7 +24,16 @@ from controllers.console.error import (
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
clear_refresh_token_from_cookie,
|
||||
extract_refresh_token,
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountRegisterError
|
||||
@ -34,6 +41,7 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@console_ns.route("/login")
|
||||
class LoginApi(Resource):
|
||||
"""Resource for user login."""
|
||||
|
||||
@ -41,11 +49,13 @@ class LoginApi(Resource):
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("password", type=str, required=True, location="json")
|
||||
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=str, required=True, location="json")
|
||||
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
@ -88,27 +98,48 @@ class LoginApi(Resource):
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/logout")
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
return {"result": "success"}
|
||||
AccountService.logout(account=account)
|
||||
flask_login.logout_user()
|
||||
return {"result": "success"}
|
||||
response = make_response({"result": "success"})
|
||||
else:
|
||||
AccountService.logout(account=account)
|
||||
flask_login.logout_user()
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
# Clear cookies on logout
|
||||
clear_access_token_from_cookie(response)
|
||||
clear_refresh_token_from_cookie(response)
|
||||
clear_csrf_token_from_cookie(response)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/reset-password")
|
||||
class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
@ -130,12 +161,15 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/email-code-login")
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
@ -162,16 +196,21 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@console_ns.route("/email-code-login/validity")
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
language = args["language"]
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
@ -205,7 +244,9 @@ class EmailCodeLoginApi(Resource):
|
||||
if account is None:
|
||||
try:
|
||||
account = AccountService.create_account_and_tenant(
|
||||
email=user_email, name=user_email, interface_language=languages[0]
|
||||
email=user_email,
|
||||
name=user_email,
|
||||
interface_language=get_valid_language(language),
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
@ -215,25 +256,36 @@ class EmailCodeLoginApi(Resource):
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
# Set HTTP-only secure cookies for tokens
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/refresh-token")
|
||||
class RefreshTokenApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("refresh_token", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
# Get refresh token from cookie instead of request body
|
||||
refresh_token = extract_refresh_token(request)
|
||||
|
||||
if not refresh_token:
|
||||
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||
|
||||
try:
|
||||
new_token_pair = AccountService.refresh_token(args["refresh_token"])
|
||||
return {"result": "success", "data": new_token_pair.model_dump()}
|
||||
new_token_pair = AccountService.refresh_token(refresh_token)
|
||||
|
||||
# Create response with new cookies
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
# Update cookies with new tokens
|
||||
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
|
||||
set_access_token_to_cookie(request, response, new_token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
|
||||
return response
|
||||
except Exception as e:
|
||||
return {"result": "fail", "data": str(e)}, 401
|
||||
|
||||
|
||||
api.add_resource(LoginApi, "/login")
|
||||
api.add_resource(LogoutApi, "/logout")
|
||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
|
||||
api.add_resource(RefreshTokenApi, "/refresh-token")
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
@ -14,8 +14,12 @@ from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from libs.token import (
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from models import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
@ -101,8 +105,10 @@ class OAuthCallback(Resource):
|
||||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
except requests.RequestException as e:
|
||||
error_text = e.response.text if e.response else str(e)
|
||||
except httpx.RequestError as e:
|
||||
error_text = str(e)
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
error_text = e.response.text
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
|
||||
@ -128,11 +134,11 @@ class OAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
|
||||
|
||||
# Check account status
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
if account.status == AccountStatus.BANNED:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
if account.status == AccountStatus.PENDING:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
@ -151,9 +157,12 @@ class OAuthCallback(Resource):
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
|
||||
)
|
||||
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
return response
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:
|
||||
|
||||
@ -1,20 +1,19 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
import flask_login
|
||||
from flask import jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
|
||||
from .. import api
|
||||
from .. import console_ns
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@ -24,8 +23,7 @@ T = TypeVar("T")
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
@ -86,12 +84,12 @@ def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProvid
|
||||
return decorated
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider")
|
||||
class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
|
||||
@ -108,13 +106,15 @@ class OAuthServerAppApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/authorize")
|
||||
class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
user_account_id = account.id
|
||||
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
@ -125,16 +125,19 @@ class OAuthServerUserAuthorizeApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/token")
|
||||
class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("grant_type", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=False, location="json")
|
||||
parser.add_argument("client_secret", type=str, required=False, location="json")
|
||||
parser.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("grant_type", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=False, location="json")
|
||||
.add_argument("client_secret", type=str, required=False, location="json")
|
||||
.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
)
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -180,6 +183,7 @@ class OAuthServerUserTokenApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/provider/account")
|
||||
class OAuthServerUserAccountApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@ -194,9 +198,3 @@ class OAuthServerUserAccountApi(Resource):
|
||||
"timezone": account.timezone,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(OAuthServerAppApi, "/oauth/provider")
|
||||
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
|
||||
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
|
||||
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")
|
||||
|
||||
@ -1,42 +1,43 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@console_ns.route("/billing/subscription")
|
||||
class Subscription(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
args = parser.parse_args()
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_subscription(
|
||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"plan",
|
||||
type=str,
|
||||
required=True,
|
||||
location="args",
|
||||
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
||||
)
|
||||
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@console_ns.route("/billing/invoices")
|
||||
class Invoices(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||
|
||||
|
||||
api.add_resource(Subscription, "/billing/subscription")
|
||||
api.add_resource(Invoices, "/billing/invoices")
|
||||
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||
|
||||
@ -1,35 +1,31 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import api
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
|
||||
|
||||
@console_ns.route("/compliance/download")
|
||||
class ComplianceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||
|
||||
return BillingService.get_compliance_download_link(
|
||||
doc_name=args.doc_name,
|
||||
account_id=current_user.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
ip=ip_address,
|
||||
device_info=device_info,
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(ComplianceApi, "/compliance/download")
|
||||
|
||||
@ -1,37 +1,47 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/data-source/integrates",
|
||||
"/data-source/integrates/<uuid:binding_id>/<string:action>",
|
||||
)
|
||||
class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_list_fields)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.tenant_id == current_tenant_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
)
|
||||
).all()
|
||||
@ -104,13 +114,28 @@ class DataSourceApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/notion/pre-import/pages")
|
||||
class DataSourceNotionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
if not credential:
|
||||
raise NotFound("Credential not found.")
|
||||
exist_page_ids = []
|
||||
with Session(db.engine) as session:
|
||||
# import notion in the exist dataset
|
||||
@ -124,7 +149,7 @@ class DataSourceNotionListApi(Resource):
|
||||
documents = session.scalars(
|
||||
select(Document).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
)
|
||||
@ -134,60 +159,82 @@ class DataSourceNotionListApi(Resource):
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# get all authorized pages
|
||||
data_source_bindings = session.scalars(
|
||||
select(DataSourceOauthBinding).filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||
datasource_name="notion_datasource",
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
if credential:
|
||||
datasource_runtime.runtime.credentials = credential
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_pages(
|
||||
user_id=current_user.id,
|
||||
datasource_parameters={},
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info["pages"]
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page["page_id"] in exist_page_ids:
|
||||
page["is_bound"] = True
|
||||
else:
|
||||
page["is_bound"] = False
|
||||
pre_import_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {"notion_info": pre_import_info_list}, 200
|
||||
)
|
||||
try:
|
||||
pages = []
|
||||
workspace_info = {}
|
||||
for message in online_document_result:
|
||||
result = message.result
|
||||
for info in result:
|
||||
workspace_info = {
|
||||
"workspace_id": info.workspace_id,
|
||||
"workspace_name": info.workspace_name,
|
||||
"workspace_icon": info.workspace_icon,
|
||||
}
|
||||
for page in info.pages:
|
||||
page_info = {
|
||||
"page_id": page.page_id,
|
||||
"page_name": page.page_name,
|
||||
"type": page.type,
|
||||
"parent_id": page.parent_id,
|
||||
"is_bound": page.page_id in exist_page_ids,
|
||||
"page_icon": page.page_icon,
|
||||
}
|
||||
pages.append(page_info)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
class DataSourceNotionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).where(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not data_source_binding:
|
||||
raise NotFound("Data source binding not found.")
|
||||
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
text_docs = extractor.extract()
|
||||
@ -197,12 +244,14 @@ class DataSourceNotionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
@ -211,21 +260,25 @@ class DataSourceNotionApi(Resource):
|
||||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
},
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
extract_settings,
|
||||
args["process_rule"],
|
||||
args["doc_form"],
|
||||
@ -234,6 +287,7 @@ class DataSourceNotionApi(Resource):
|
||||
return response.model_dump(), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
|
||||
class DataSourceNotionDatasetSyncApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -250,6 +304,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync")
|
||||
class DataSourceNotionDocumentSyncApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -266,16 +321,3 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
document_indexing_sync_task.delay(dataset_id_str, document_id_str)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>")
|
||||
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages")
|
||||
api.add_resource(
|
||||
DataSourceNotionApi,
|
||||
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
|
||||
api.add_resource(
|
||||
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
|
||||
)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import flask_restx
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
@ -20,32 +20,100 @@ from controllers.console.wraps import (
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get supported retrieval methods based on vector database type.
|
||||
|
||||
Args:
|
||||
vector_type: Vector database type, can be None
|
||||
is_mock: Whether this is a Mock API, affects MILVUS handling
|
||||
|
||||
Returns:
|
||||
Dictionary containing supported retrieval methods
|
||||
|
||||
Raises:
|
||||
ValueError: If vector_type is None or unsupported
|
||||
"""
|
||||
if vector_type is None:
|
||||
raise ValueError("Vector store type is not configured.")
|
||||
|
||||
# Define vector database types that only support semantic search
|
||||
semantic_only_types = {
|
||||
VectorType.RELYT,
|
||||
VectorType.TIDB_VECTOR,
|
||||
VectorType.CHROMA,
|
||||
VectorType.PGVECTO_RS,
|
||||
VectorType.VIKINGDB,
|
||||
VectorType.UPSTASH,
|
||||
}
|
||||
|
||||
# Define vector database types that support all retrieval methods
|
||||
full_search_types = {
|
||||
VectorType.QDRANT,
|
||||
VectorType.WEAVIATE,
|
||||
VectorType.OPENSEARCH,
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.MYSCALE,
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
VectorType.ELASTICSEARCH_JA,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.VASTBASE,
|
||||
VectorType.TIDB_ON_QDRANT,
|
||||
VectorType.LINDORM,
|
||||
VectorType.COUCHBASE,
|
||||
VectorType.OPENGAUSS,
|
||||
VectorType.OCEANBASE,
|
||||
VectorType.TABLESTORE,
|
||||
VectorType.HUAWEI_CLOUD,
|
||||
VectorType.TENCENT,
|
||||
VectorType.MATRIXONE,
|
||||
VectorType.CLICKZETTA,
|
||||
VectorType.BAIDU,
|
||||
VectorType.ALIBABACLOUD_MYSQL,
|
||||
}
|
||||
|
||||
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
full_methods = {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
|
||||
if vector_type == VectorType.MILVUS:
|
||||
return semantic_methods if is_mock else full_methods
|
||||
|
||||
if vector_type in semantic_only_types:
|
||||
return semantic_methods
|
||||
elif vector_type in full_search_types:
|
||||
return full_methods
|
||||
else:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
|
||||
|
||||
@console_ns.route("/datasets")
|
||||
@ -68,6 +136,7 @@ class DatasetListApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
ids = request.args.getlist("ids")
|
||||
@ -76,15 +145,15 @@ class DatasetListApi(Resource):
|
||||
tag_ids = request.args.getlist("tag_ids")
|
||||
include_all = request.args.get("include_all", default="false").lower() == "true"
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
|
||||
else:
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
|
||||
page, limit, current_tenant_id, current_user, search, tag_ids, include_all
|
||||
)
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
|
||||
|
||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||
|
||||
@ -92,7 +161,7 @@ class DatasetListApi(Resource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
|
||||
for item in data:
|
||||
# convert embedding_model_provider to plugin standard format
|
||||
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
|
||||
@ -137,50 +206,53 @@ class DatasetListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
nullable=True,
|
||||
choices=Dataset.PROVIDER_LIST,
|
||||
required=False,
|
||||
default="vendor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
nullable=True,
|
||||
choices=Dataset.PROVIDER_LIST,
|
||||
required=False,
|
||||
default="vendor",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -188,7 +260,7 @@ class DatasetListApi(Resource):
|
||||
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
name=args["name"],
|
||||
description=args["description"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
@ -216,6 +288,7 @@ class DatasetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -224,7 +297,7 @@ class DatasetApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = marshal(dataset, dataset_detail_fields)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
if dataset.embedding_model_provider:
|
||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||
@ -235,7 +308,7 @@ class DatasetApi(Resource):
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
|
||||
|
||||
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
|
||||
|
||||
@ -281,64 +354,76 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
help="Invalid permission.",
|
||||
)
|
||||
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
||||
parser.add_argument(
|
||||
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
||||
)
|
||||
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||
|
||||
parser.add_argument(
|
||||
"external_retrieval_model",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external retrieval model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge id.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=(
|
||||
DatasetPermissionEnum.ONLY_ME,
|
||||
DatasetPermissionEnum.ALL_TEAM,
|
||||
DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
),
|
||||
help="Invalid permission.",
|
||||
)
|
||||
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
||||
.add_argument(
|
||||
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
|
||||
)
|
||||
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||
.add_argument(
|
||||
"external_retrieval_model",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external retrieval model.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge id.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid icon info.",
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check embedding model setting
|
||||
if (
|
||||
@ -360,8 +445,8 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = marshal(dataset, dataset_detail_fields)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
@ -385,9 +470,9 @@ class DatasetApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor or current_user.is_dataset_operator:
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
@ -426,6 +511,7 @@ class DatasetQueryApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -460,32 +546,31 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# validate args
|
||||
DocumentService.estimate_args_validate(args)
|
||||
extract_settings = []
|
||||
if args["info_list"]["data_source_type"] == "upload_file":
|
||||
file_ids = args["info_list"]["file_info_list"]["file_ids"]
|
||||
file_details = db.session.scalars(
|
||||
select(UploadFile).where(
|
||||
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
|
||||
)
|
||||
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
|
||||
).all()
|
||||
|
||||
if file_details is None:
|
||||
@ -494,7 +579,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value,
|
||||
datasource_type=DatasourceType.FILE,
|
||||
upload_file=file_detail,
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
@ -503,15 +588,19 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
notion_info_list = args["info_list"]["notion_info_list"]
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
},
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
@ -519,15 +608,17 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
website_info_list = args["info_list"]["website_info_list"]
|
||||
for url in website_info_list["urls"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
"url": url,
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": "crawl",
|
||||
"only_main_content": website_info_list["only_main_content"],
|
||||
},
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
"url": url,
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": "crawl",
|
||||
"only_main_content": website_info_list["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
@ -536,7 +627,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
extract_settings,
|
||||
args["process_rule"],
|
||||
args["doc_form"],
|
||||
@ -567,6 +658,7 @@ class DatasetRelatedAppListApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(related_app_list)
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -598,11 +690,10 @@ class DatasetIndexingStatusApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
|
||||
)
|
||||
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
@ -654,10 +745,9 @@ class DatasetApiKeyApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(api_key_list)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
|
||||
)
|
||||
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
return {"items": keys}
|
||||
|
||||
@ -667,17 +757,18 @@ class DatasetApiKeyApi(Resource):
|
||||
@marshal_with(api_key_fields)
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
|
||||
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
flask_restx.abort(
|
||||
api.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
@ -685,7 +776,7 @@ class DatasetApiKeyApi(Resource):
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
api_token = ApiToken()
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
api_token.tenant_id = current_tenant_id
|
||||
api_token.token = key
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
@ -705,6 +796,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, api_key_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
api_key_id = str(api_key_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
@ -714,7 +806,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
.where(
|
||||
ApiToken.tenant_id == current_user.current_tenant_id,
|
||||
ApiToken.tenant_id == current_tenant_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
@ -722,7 +814,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
)
|
||||
|
||||
if key is None:
|
||||
flask_restx.abort(404, message="API key not found")
|
||||
api.abort(404, message="API key not found")
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
@ -730,6 +822,19 @@ class DatasetApiDeleteApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
|
||||
class DatasetEnableApiApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id, status):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/api-base-info")
|
||||
class DatasetApiBaseUrlApi(Resource):
|
||||
@api.doc("get_dataset_api_base_info")
|
||||
@ -752,49 +857,7 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.TIDB_ON_QDRANT
|
||||
| VectorType.LINDORM
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.MILVUS
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.TENCENT
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||
@ -807,48 +870,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
match vector_type:
|
||||
case (
|
||||
VectorType.MILVUS
|
||||
| VectorType.RELYT
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
VectorType.QDRANT
|
||||
| VectorType.WEAVIATE
|
||||
| VectorType.OPENSEARCH
|
||||
| VectorType.ANALYTICDB
|
||||
| VectorType.MYSCALE
|
||||
| VectorType.ORACLE
|
||||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.ELASTICSEARCH_JA
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.VASTBASE
|
||||
| VectorType.LINDORM
|
||||
| VectorType.OPENGAUSS
|
||||
| VectorType.OCEANBASE
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.TENCENT
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.MATRIXONE
|
||||
| VectorType.CLICKZETTA
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
RetrievalMethod.FULL_TEXT_SEARCH.value,
|
||||
RetrievalMethod.HYBRID_SEARCH.value,
|
||||
]
|
||||
}
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||
@ -883,6 +905,7 @@ class DatasetPermissionUserListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import asc, desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
@ -42,7 +43,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import (
|
||||
dataset_and_document_fields,
|
||||
@ -51,8 +52,9 @@ from fields.document_fields import (
|
||||
document_with_segments_fields,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
|
||||
@ -61,6 +63,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -75,12 +78,13 @@ class DocumentResource(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
if document.tenant_id != current_user.current_tenant_id:
|
||||
if document.tenant_id != current_tenant_id:
|
||||
raise Forbidden("No permission.")
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -108,6 +112,7 @@ class GetProcessRuleApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
req_data = request.args
|
||||
|
||||
document_id = req_data.get("document_id")
|
||||
@ -164,6 +169,7 @@ class DatasetDocumentListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
@ -195,7 +201,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
|
||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
|
||||
|
||||
if search:
|
||||
search = f"%{search}%"
|
||||
@ -209,13 +215,13 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
if sort == "hit_count":
|
||||
sub_query = (
|
||||
db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
|
||||
sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)),
|
||||
sort_logic(Document.position),
|
||||
)
|
||||
elif sort == "created_at":
|
||||
@ -269,6 +275,7 @@ class DatasetDocumentListApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -285,23 +292,23 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument("data_source", type=dict, required=False, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=False, location="json")
|
||||
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
||||
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||
)
|
||||
.add_argument("data_source", type=dict, required=False, location="json")
|
||||
.add_argument("process_rule", type=dict, required=False, location="json")
|
||||
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
|
||||
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
@ -368,37 +375,38 @@ class DatasetInitApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
required=True,
|
||||
nullable=False,
|
||||
location="json",
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
required=True,
|
||||
nullable=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=args["embedding_model_provider"],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=args["embedding_model"],
|
||||
@ -415,7 +423,9 @@ class DatasetInitApi(Resource):
|
||||
|
||||
try:
|
||||
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
||||
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
|
||||
tenant_id=current_tenant_id,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -441,6 +451,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
@ -449,7 +460,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
raise DocumentAlreadyFinishedError()
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
|
||||
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
|
||||
|
||||
@ -469,14 +480,14 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
||||
try:
|
||||
estimate_response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
[extract_setting],
|
||||
data_process_rule_dict,
|
||||
document.doc_form,
|
||||
@ -499,18 +510,20 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
return response, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-estimate")
|
||||
class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, batch):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
documents = self.get_batch_documents(dataset_id, batch)
|
||||
if not documents:
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
@ -523,7 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -531,7 +544,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
@ -539,13 +552,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
},
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
@ -553,15 +569,17 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
},
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
@ -571,7 +589,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
extract_settings,
|
||||
data_process_rule_dict,
|
||||
document.doc_form,
|
||||
@ -591,6 +609,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
raise IndexingEstimateError(str(e))
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
||||
class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -714,7 +733,7 @@ class DocumentApi(DocumentResource):
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
@ -727,7 +746,7 @@ class DocumentApi(DocumentResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
@ -747,7 +766,7 @@ class DocumentApi(DocumentResource):
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
@ -760,7 +779,7 @@ class DocumentApi(DocumentResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
@ -821,6 +840,7 @@ class DocumentProcessingApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
@ -871,6 +891,7 @@ class DocumentMetadataApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
@ -910,6 +931,7 @@ class DocumentMetadataApi(DocumentResource):
|
||||
return {"result": "success", "message": "Document metadata updated."}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
|
||||
class DocumentStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -917,6 +939,7 @@ class DocumentStatusApi(DocumentResource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
@ -946,6 +969,7 @@ class DocumentStatusApi(DocumentResource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
|
||||
class DocumentPauseApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -979,6 +1003,7 @@ class DocumentPauseApi(DocumentResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
|
||||
class DocumentRecoverApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1009,6 +1034,7 @@ class DocumentRecoverApi(DocumentResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/retry")
|
||||
class DocumentRetryApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1017,8 +1043,9 @@ class DocumentRetryApi(DocumentResource):
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"document_ids", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -1052,6 +1079,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
class DocumentRenameApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1059,12 +1087,14 @@ class DocumentRenameApi(DocumentResource):
|
||||
@marshal_with(document_fields)
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -1075,12 +1105,14 @@ class DocumentRenameApi(DocumentResource):
|
||||
return document
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
@ -1089,7 +1121,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if document.tenant_id != current_user.current_tenant_id:
|
||||
if document.tenant_id != current_tenant_id:
|
||||
raise Forbidden("No permission.")
|
||||
if document.data_source_type != "website_crawl":
|
||||
raise ValueError("Document is not a website document.")
|
||||
@ -1100,3 +1132,39 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
DocumentService.sync_website_document(dataset_id, document)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log")
|
||||
class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.filter_by(document_id=document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if not log:
|
||||
return {
|
||||
"datasource_info": None,
|
||||
"datasource_type": None,
|
||||
"input_data": None,
|
||||
"datasource_node_id": None,
|
||||
}, 200
|
||||
return {
|
||||
"datasource_info": json.loads(log.datasource_info),
|
||||
"datasource_type": log.datasource_type,
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import (
|
||||
ChildChunkDeleteIndexError,
|
||||
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
@ -37,11 +36,14 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -58,13 +60,15 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("limit", type=int, default=20, location="args")
|
||||
parser.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||
parser.add_argument("enabled", type=str, default="all", location="args")
|
||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
||||
parser.add_argument("page", type=int, default=1, location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||
.add_argument("enabled", type=str, default="all", location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -78,7 +82,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id,
|
||||
DocumentSegment.tenant_id == current_tenant_id,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
)
|
||||
@ -114,6 +118,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -139,6 +145,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
|
||||
class DatasetDocumentSegmentApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -146,6 +153,8 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
@ -169,7 +178,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -193,6 +202,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
class DatasetDocumentSegmentAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -201,6 +211,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -218,7 +230,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -234,16 +246,19 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.create_segment(args, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -251,6 +266,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -268,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -283,7 +300,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -296,16 +313,18 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
parser.add_argument(
|
||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
.add_argument(
|
||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
@ -313,6 +332,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -329,7 +350,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -345,6 +366,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
||||
"/datasets/batch_import_status/<uuid:job_id>",
|
||||
)
|
||||
class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -353,6 +378,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -364,8 +391,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"upload_file_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
upload_file_id = args["upload_file_id"]
|
||||
|
||||
@ -384,7 +412,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_create_segment_to_index_task.delay(
|
||||
str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
|
||||
str(job_id),
|
||||
upload_file_id,
|
||||
dataset_id,
|
||||
document_id,
|
||||
current_tenant_id,
|
||||
current_user.id,
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 500
|
||||
@ -393,7 +426,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, job_id):
|
||||
def get(self, job_id=None, dataset_id=None, document_id=None):
|
||||
if job_id is None:
|
||||
raise NotFound("The job does not exist.")
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
@ -403,6 +438,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
|
||||
class ChildChunkAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -411,6 +447,8 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -425,7 +463,7 @@ class ChildChunkAddApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -437,7 +475,7 @@ class ChildChunkAddApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -453,11 +491,13 @@ class ChildChunkAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
@ -466,6 +506,8 @@ class ChildChunkAddApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id, segment_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -482,15 +524,17 @@ class ChildChunkAddApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("limit", type=int, default=20, location="args")
|
||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
||||
parser.add_argument("page", type=int, default=1, location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -513,6 +557,8 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -529,7 +575,7 @@ class ChildChunkAddApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -542,23 +588,30 @@ class ChildChunkAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"chunks", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
|
||||
chunks_data = args["chunks"]
|
||||
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
|
||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>"
|
||||
)
|
||||
class ChildChunkUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -575,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -586,7 +639,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
@ -613,6 +666,8 @@ class ChildChunkUpdateApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -629,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -640,7 +695,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
@ -656,37 +711,13 @@ class ChildChunkUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
child_chunk = SegmentService.update_child_chunk(
|
||||
args.get("content"), child_chunk, segment, document, dataset
|
||||
)
|
||||
content = args["content"]
|
||||
child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
|
||||
)
|
||||
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentUpdateApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasetDocumentSegmentBatchImportApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
|
||||
"/datasets/batch_import_status/<uuid:job_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
ChildChunkAddApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
|
||||
)
|
||||
api.add_resource(
|
||||
ChildChunkUpdateApi,
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
|
||||
)
|
||||
|
||||
@ -71,3 +71,9 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
|
||||
error_code = "child_chunk_delete_index_error"
|
||||
description = "Delete child chunk index failed: {message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class PipelineNotFoundError(BaseHTTPException):
|
||||
error_code = "pipeline_not_found"
|
||||
description = "Pipeline not found."
|
||||
code = 404
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
@ -8,14 +7,14 @@ from controllers.console import api, console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
return name
|
||||
@ -37,12 +36,13 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
|
||||
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
||||
page, limit, current_user.current_tenant_id, search
|
||||
page, limit, current_tenant_id, search
|
||||
)
|
||||
response = {
|
||||
"data": [item.to_dict() for item in external_knowledge_apis],
|
||||
@ -57,20 +57,23 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -82,7 +85,7 @@ class ExternalApiTemplateListApi(Resource):
|
||||
|
||||
try:
|
||||
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, args=args
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
@ -112,28 +115,31 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
|
||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
args=args,
|
||||
@ -145,13 +151,13 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor or current_user.is_dataset_operator:
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@ -196,21 +202,24 @@ class ExternalDatasetCreateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -220,7 +229,7 @@ class ExternalDatasetCreateApi(Resource):
|
||||
|
||||
try:
|
||||
dataset = ExternalDatasetService.create_external_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
args=args,
|
||||
)
|
||||
@ -252,6 +261,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -262,10 +272,12 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
@ -301,15 +313,17 @@ class BedrockRetrievalApi(Resource):
|
||||
)
|
||||
@api.response(200, "Bedrock retrieval test completed")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
||||
parser.add_argument(
|
||||
"query",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
||||
.add_argument(
|
||||
"query",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
)
|
||||
.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
||||
)
|
||||
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Call the knowledge retrieval service
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
import services
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@ -20,6 +19,8 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
@ -29,6 +30,7 @@ logger = logging.getLogger(__name__)
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -46,15 +48,17 @@ class DatasetsHitTestingBase:
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
@ -16,6 +15,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
class DatasetMetadataCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -23,11 +23,14 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def post(self, dataset_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataArgs(**args)
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -50,6 +53,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
return MetadataService.get_dataset_metadatas(dataset), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
class DatasetMetadataApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -57,9 +61,10 @@ class DatasetMetadataApi(Resource):
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
name = args["name"]
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
@ -68,7 +73,7 @@ class DatasetMetadataApi(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
|
||||
return metadata, 200
|
||||
|
||||
@setup_required
|
||||
@ -76,6 +81,7 @@ class DatasetMetadataApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -87,6 +93,7 @@ class DatasetMetadataApi(Resource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/metadata/built-in")
|
||||
class DatasetMetadataBuiltInFieldApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -97,12 +104,14 @@ class DatasetMetadataBuiltInFieldApi(Resource):
|
||||
return {"fields": built_in_fields}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -116,30 +125,26 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
class DocumentMetadataEditApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"operation_data", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataOperationData(**args)
|
||||
metadata_args = MetadataOperationData.model_validate(args)
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetMetadataCreateApi, "/datasets/<uuid:dataset_id>/metadata")
|
||||
api.add_resource(DatasetMetadataApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
api.add_resource(DatasetMetadataBuiltInFieldApi, "/datasets/metadata/built-in")
|
||||
api.add_resource(DatasetMetadataBuiltInFieldActionApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
api.add_resource(DocumentMetadataEditApi, "/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
|
||||
354
api/controllers/console/datasets/rag_pipeline/datasource_auth.py
Normal file
354
api/controllers/console/datasets/rag_pipeline/datasource_auth.py
Normal file
@ -0,0 +1,354 @@
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
|
||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, provider_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
oauth_config = DatasourceProviderService().get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
if not oauth_config:
|
||||
raise ValueError(f"No OAuth Client Config for {provider_id}")
|
||||
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=current_user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_config,
|
||||
)
|
||||
response = make_response(jsonable_encoder(authorization_url_response))
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/callback")
|
||||
class DatasourceOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider_id: str):
|
||||
context_id = request.cookies.get("context_id") or request.args.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
oauth_client_params = datasource_provider_service.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
if not oauth_client_params:
|
||||
raise NotFound()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
oauth_handler = OAuthHandler()
|
||||
oauth_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
credential_id = context.get("credential_id")
|
||||
if credential_id:
|
||||
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
credential_id=context.get("credential_id"),
|
||||
)
|
||||
else:
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_datasource = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@api.expect(parser_datasource)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_datasource.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
try:
|
||||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
parser_datasource_delete = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@api.expect(parser_datasource_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
|
||||
args = parser_datasource_delete.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_datasource_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@api.expect(parser_datasource_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
args = parser_datasource_update.parse_args()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
credentials=args.get("credentials", {}),
|
||||
name=args.get("name", None),
|
||||
)
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/list")
|
||||
class DatasourceAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/default-list")
|
||||
class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
parser_datasource_custom = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@api.expect(parser_datasource_custom)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_datasource_custom.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@api.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_default.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_update_name = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@api.expect(parser_update_name)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_update_name.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
@ -0,0 +1,56 @@
|
||||
from flask_restx import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@api.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run datasource content preview
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
preview_content = rag_pipeline_service.run_datasource_node_preview(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=True,
|
||||
credential_id=args.get("credential_id"),
|
||||
)
|
||||
return preview_content, 200
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user