mirror of
https://github.com/langgenius/dify.git
synced 2026-01-19 19:55:06 +08:00
Compare commits
641 Commits
fix/openin
...
deploy/tri
| Author | SHA1 | Date | |
|---|---|---|---|
| 7dfe615613 | |||
| a1a3fa0283 | |||
| ff7344f3d3 | |||
| bcd33be22a | |||
| 991f31f195 | |||
| f8b10c2272 | |||
| 369892634d | |||
| 8e5cb86409 | |||
| a85afe4d07 | |||
| e8f93380d1 | |||
| 0a22bc5d05 | |||
| 626d4f3e35 | |||
| f4db5f9973 | |||
| 70dabe318c | |||
| f94972f662 | |||
| d162f7e5ef | |||
| 2f8cb2a1af | |||
| b91d22375f | |||
| a087ace697 | |||
| 0af8a7b958 | |||
| 861098714b | |||
| 63b345110e | |||
| 247069c7e9 | |||
| bb096f4ae3 | |||
| a37497ffb5 | |||
| 02adf4ff06 | |||
| acbc886ecd | |||
| 0a2d478749 | |||
| 95528ad8e5 | |||
| ddad2460f3 | |||
| a8491c26ea | |||
| 0aed7afdc0 | |||
| 18b800a33b | |||
| c64fe595d3 | |||
| dd3b1ccd45 | |||
| 6f927b4a62 | |||
| c76bb8ffa0 | |||
| 4dcd871cef | |||
| abe1d31ae0 | |||
| 2d71fff2b2 | |||
| c4f61b8ae7 | |||
| c51ab6ec37 | |||
| 1fc2255219 | |||
| 037389137d | |||
| 8cd3e84c06 | |||
| b3c6ac1430 | |||
| 68bb97919a | |||
| f268d7c7be | |||
| d695a79ba1 | |||
| cd5a745bd2 | |||
| 51e5f422c4 | |||
| ec3b2b40c2 | |||
| 67ae3e9253 | |||
| d38e3b7792 | |||
| 43d27edef2 | |||
| 94b87eac72 | |||
| fd31af6012 | |||
| 228deccec2 | |||
| 639f1d31f7 | |||
| ec786fe236 | |||
| fe3a6ef049 | |||
| 8b761319f6 | |||
| 002d8769b0 | |||
| 5aba111297 | |||
| dc9b3a7e03 | |||
| 5f2e0d6347 | |||
| 1f72571c06 | |||
| 820925a866 | |||
| 299bd351fd | |||
| 13bf6547ee | |||
| 1b733abe82 | |||
| 5782e26ab2 | |||
| 38d329e75a | |||
| 58f448a926 | |||
| 7a7fea40d9 | |||
| 0309545ff1 | |||
| 6deabfdad3 | |||
| f9b4c31344 | |||
| 8d8800e632 | |||
| 4ca4493084 | |||
| 7efa0df1fd | |||
| b786e101e5 | |||
| 09a8046b10 | |||
| 2f6b3f1c5f | |||
| 2551f6f279 | |||
| 01afa56166 | |||
| 5815950092 | |||
| 766e16b26f | |||
| 0fdb4e7c12 | |||
| 64babb35e2 | |||
| 38522e5dfa | |||
| 4ccc150fd1 | |||
| a4c57017d5 | |||
| b2a7cec644 | |||
| ddc5cbe865 | |||
| 1e23957657 | |||
| 2731b04ff9 | |||
| e8ca80a61a | |||
| e76129b5a4 | |||
| 6635ea62c2 | |||
| 6b8c649876 | |||
| af587f3869 | |||
| 1c1f124891 | |||
| b353a126d8 | |||
| ef0e1031b0 | |||
| d7010f582f | |||
| d696b9f35e | |||
| 665d49d375 | |||
| 26a1c84881 | |||
| dbecba710b | |||
| 591414307a | |||
| 1241cab113 | |||
| 490b7ac43c | |||
| 0f521b26ae | |||
| e4ec4e1470 | |||
| 203c2f0456 | |||
| b502d30e77 | |||
| a486c47b1e | |||
| f76a3f545c | |||
| b5650b579d | |||
| 83702762c8 | |||
| abc13ef762 | |||
| ce00388278 | |||
| 4a76318877 | |||
| e073e755f9 | |||
| 57b405c4c2 | |||
| 2181ffdc89 | |||
| 82dac2eba0 | |||
| 58be008676 | |||
| eed38c8b2a | |||
| 6bd114285c | |||
| 25698ccd54 | |||
| bb3aa0178d | |||
| 751ce4ec41 | |||
| da98a38b14 | |||
| 034e3e85e9 | |||
| aab95d0626 | |||
| 15ea27868e | |||
| bcbd3de336 | |||
| a0daab2711 | |||
| e1d11681c0 | |||
| 8a995d0c21 | |||
| 6241b87f90 | |||
| 2c9e435558 | |||
| b12057b7e5 | |||
| 2445d04d19 | |||
| a58986eb06 | |||
| a39b151d5f | |||
| 3841e8578f | |||
| e0824c2d93 | |||
| c75a4e6309 | |||
| 1dfde240cb | |||
| c6e6f3b7cb | |||
| ea320ce055 | |||
| 6d3ed468d8 | |||
| a6c6bcf95c | |||
| 63b8bbbab3 | |||
| 33ff01d04c | |||
| ae126fd56f | |||
| 9fed2dc065 | |||
| 2e0964e0b0 | |||
| 237bb4595b | |||
| 4486b54680 | |||
| 1a2f8dfcb4 | |||
| 3c30d0f41b | |||
| 5f61ca5e6f | |||
| 06466cb73a | |||
| c5b6219006 | |||
| ae5b5a6aa9 | |||
| a4c4d18f42 | |||
| 3cf19dc07f | |||
| 73c58e4cbb | |||
| c2043d0f6d | |||
| cad2991946 | |||
| e260815c5e | |||
| b4e7239ac7 | |||
| 4b6f4081d6 | |||
| d1c9183d3b | |||
| 2f9705eb6f | |||
| 0e3fab1f9f | |||
| 2431ddfde6 | |||
| 1e4e963d8c | |||
| 522508df28 | |||
| 859f73c19d | |||
| 82c11e36ea | |||
| a6cd2ad880 | |||
| b2a604b801 | |||
| 7c060fc35c | |||
| 48e39b60a8 | |||
| f038aa4746 | |||
| 4833d39ab3 | |||
| fa910be0f6 | |||
| bc274e7300 | |||
| 7b1fc4d2e6 | |||
| 204d5f1bb9 | |||
| 8fc1c7d994 | |||
| 879869d3e3 | |||
| 1d2cdf3489 | |||
| a5d0e68675 | |||
| 605e543372 | |||
| c432f601ab | |||
| e8d03a422d | |||
| 6be013e072 | |||
| 014cbaf387 | |||
| 1be38183e5 | |||
| 8bab42e224 | |||
| 99e9fc751b | |||
| a1b735a4c0 | |||
| c0b7ffd5d0 | |||
| 012877d8d4 | |||
| 41bb6f3109 | |||
| 88c9b18cb6 | |||
| 6efdc94661 | |||
| 68526c09fc | |||
| a78bc507c0 | |||
| e83c7438cb | |||
| 82068a6918 | |||
| 108bcbeb7c | |||
| c4b02be6d3 | |||
| 30eebf804f | |||
| ad7fdd18d0 | |||
| 5d2fbf5215 | |||
| 4a89403566 | |||
| e0c05b2123 | |||
| 85b99580ea | |||
| 15fbedfcad | |||
| 1e6d0de48b | |||
| cad751c00c | |||
| a47276ac24 | |||
| 20403c69b2 | |||
| ffc04f2a9b | |||
| d1580791e4 | |||
| c74eb4fcf3 | |||
| a798534337 | |||
| 470883858e | |||
| 4f4911686d | |||
| 6d479dcdbb | |||
| 24348c40a6 | |||
| a39b50adbb | |||
| 81832c14ee | |||
| b86022c64a | |||
| 45e816a9f6 | |||
| 667b1c37a3 | |||
| b75d533f9b | |||
| aece55d82f | |||
| c432b398f4 | |||
| 9cb2645793 | |||
| 6ac61bd585 | |||
| b02165ffe6 | |||
| 6c576e2c66 | |||
| b0e7e7752f | |||
| 2799b79e8c | |||
| 805a1479f9 | |||
| fe6538b08d | |||
| 1bbb9d6644 | |||
| 5c06e285ec | |||
| 19c92fd670 | |||
| 6026bd873b | |||
| 1369119a0c | |||
| b76e17b25d | |||
| ca7794305b | |||
| fd255e81e1 | |||
| 09d31d1263 | |||
| 47dc26f011 | |||
| 123bb3ec08 | |||
| 90f77282e3 | |||
| 5208867ccc | |||
| edc7ccc795 | |||
| c9798f6425 | |||
| 20ecf7f1d0 | |||
| 9dcb780fcb | |||
| 1cb7b09933 | |||
| 2c62a77cf4 | |||
| b9bc48d8dd | |||
| ed234e311b | |||
| 9843fec393 | |||
| aa4cabdeb5 | |||
| eea713b668 | |||
| fc62538a94 | |||
| 7994144df7 | |||
| e153c483b6 | |||
| 422bb4d4bb | |||
| 87a80d7613 | |||
| e91105ca87 | |||
| 37903722fe | |||
| f4c82d0010 | |||
| fe50093c18 | |||
| 4317af1e90 | |||
| 61a0fcc2ea | |||
| f627348b11 | |||
| 87fb9a6b69 | |||
| 97a2e2ec2e | |||
| 68d357d7f6 | |||
| a103ad3ee7 | |||
| f65d5a9761 | |||
| 6e0a5f5bbd | |||
| 22f858152f | |||
| 775d2e14fc | |||
| 744b287e67 | |||
| c0fc5d98f0 | |||
| 08ea79d730 | |||
| f31b821cc0 | |||
| 34be16874f | |||
| e9738b891f | |||
| 829796514a | |||
| ef1db35f80 | |||
| f9c67621ca | |||
| e29e8e3180 | |||
| 7a81e720d4 | |||
| 55600c0eb1 | |||
| 35e41d7d68 | |||
| b610cf9a11 | |||
| c8e9edc024 | |||
| 471cd760d7 | |||
| 7f48c57edf | |||
| 6569801162 | |||
| 9dd83f50a7 | |||
| 59c56b1b0d | |||
| 94cd2de940 | |||
| 3c23375607 | |||
| 56047f638f | |||
| 9c01d3e775 | |||
| c85c87f3da | |||
| eaa02e3d55 | |||
| 0219222a60 | |||
| dba659b220 | |||
| ee6458768e | |||
| ed3d02dc6d | |||
| 95471b1188 | |||
| 6190cfbfd8 | |||
| 11f2f95103 | |||
| 2abbc14703 | |||
| b2b2816ade | |||
| 4461df1bd9 | |||
| f7f6b4a8b0 | |||
| 41be581594 | |||
| 20ad5b7ac2 | |||
| a1c0bd7a1c | |||
| fd7c4e8a6d | |||
| 41e549af14 | |||
| b7360140ee | |||
| c71f7c7613 | |||
| c905c47775 | |||
| 4ca7ba000c | |||
| f260627660 | |||
| 1e9142c213 | |||
| 82890fe38e | |||
| 7dc7c8af98 | |||
| addebc465a | |||
| 5ab315aeaf | |||
| f092bc1912 | |||
| 23b49b8304 | |||
| 9e97248ede | |||
| d532b06310 | |||
| 07a2281730 | |||
| 42385f3ffa | |||
| c597234374 | |||
| 3de73f07c6 | |||
| 0caeaf6e5c | |||
| 3395297c3e | |||
| e60a7c7143 | |||
| 0e62a66cc2 | |||
| ff32dff163 | |||
| 543c5236e7 | |||
| 341b3ae7c9 | |||
| f01907aac2 | |||
| a7c855cab8 | |||
| 29afc0657d | |||
| d9860b8907 | |||
| dc1ae57dc6 | |||
| d6bd2a9bdb | |||
| c9eed67cf6 | |||
| 0ded6303c1 | |||
| b6e0abadab | |||
| 43bcf40f80 | |||
| f06025a342 | |||
| 24fb95b050 | |||
| 49fca63927 | |||
| ce5fe86430 | |||
| 666586b59c | |||
| 8a2851551a | |||
| a2fe4a28c3 | |||
| 417ebd160b | |||
| 82be305680 | |||
| 03002f4971 | |||
| 1e7e8a8988 | |||
| a715d5ac23 | |||
| 398c8117fe | |||
| f45c18ee35 | |||
| 15c1db42dd | |||
| a31c01f8d9 | |||
| 62753cdf13 | |||
| dc7ce125ad | |||
| eabdb09f8e | |||
| fa6d03c979 | |||
| 634fb192ef | |||
| a4b38e7521 | |||
| 8ff6de91b0 | |||
| 7fa0ad3161 | |||
| 53b21eea61 | |||
| 2f3a61b51b | |||
| 8bca7814f4 | |||
| 92c81b1833 | |||
| 44553d412c | |||
| 95ce224df0 | |||
| 8555635967 | |||
| e843fe8aa6 | |||
| b198c9474a | |||
| 4bb00b83d9 | |||
| c91cbf6b97 | |||
| f6ede6f1c1 | |||
| 65976b27fe | |||
| 2d73ee64a3 | |||
| c61c2b0abd | |||
| 40d3332690 | |||
| 8e45753c68 | |||
| 73e217ab0d | |||
| 26ff59172e | |||
| bebb4ffbaa | |||
| 523da66134 | |||
| e1ca7a9bdb | |||
| 9a8cf709ba | |||
| f909040567 | |||
| 845adb664a | |||
| 0c6cae2d59 | |||
| a893ee0ffc | |||
| 82b63cc6e2 | |||
| c327cfa86e | |||
| 82219c1162 | |||
| cfc3f1527a | |||
| caf1a5fbab | |||
| 4a6398fc1f | |||
| 2bcf96565a | |||
| 9a9d6a4a2b | |||
| 05f66fcf0d | |||
| ea8245a91b | |||
| 759a932bb7 | |||
| fb6f05c267 | |||
| ff9b74efeb | |||
| d6e7543ba6 | |||
| e45d5700ec | |||
| 4e6682bd85 | |||
| 32c715c4d0 | |||
| c11cdf7468 | |||
| 6217c96576 | |||
| 977690590e | |||
| fd845c8b6c | |||
| d7d9abb007 | |||
| 9f22b2726b | |||
| f28b519556 | |||
| 762cf91133 | |||
| 9dd3dcff2b | |||
| 34fbcc9457 | |||
| 9cc8ac981b | |||
| 1153dcef69 | |||
| f811471b18 | |||
| 2382229c7d | |||
| f0e739be43 | |||
| 4dccdf9478 | |||
| 4c37d650d3 | |||
| 1b334e6966 | |||
| d463bd6323 | |||
| 8c298b33cd | |||
| dc1a380888 | |||
| 7e9be4d3d9 | |||
| 5579521ffc | |||
| ab1059134d | |||
| fe2ac66a52 | |||
| f87db2652b | |||
| 3f9f02b9e7 | |||
| 578247ffbc | |||
| 9a5f214623 | |||
| 141ca8904a | |||
| 4488c090b2 | |||
| 59c1fde351 | |||
| cf7ff76165 | |||
| ac79691d69 | |||
| 1a37989769 | |||
| 830f891a74 | |||
| 5937a66e22 | |||
| 894e38f713 | |||
| e4b5b0e5fd | |||
| 598dd1f816 | |||
| 35e24d4d14 | |||
| fea2ffb3ba | |||
| 64f55d55a1 | |||
| bfda4ce7e6 | |||
| 4f7cb7cd2a | |||
| 6517323add | |||
| 531a0b755a | |||
| 91bb8ae4d2 | |||
| 8cafc20098 | |||
| 9d5300440c | |||
| 58524d6d2b | |||
| 19cc6ea993 | |||
| d7f0a31e24 | |||
| 312974aa20 | |||
| d19c100166 | |||
| a8ad80c405 | |||
| 650e38e17f | |||
| 24612adf2c | |||
| 06649f6c21 | |||
| 8b61f5e9c4 | |||
| 6432898e7a | |||
| cced33d068 | |||
| bd01af6415 | |||
| 35011b810d | |||
| f295c7532c | |||
| 7065b67d07 | |||
| c0b50ef61d | |||
| 1d8cca4fa2 | |||
| 3474c179e6 | |||
| 433dad7e1a | |||
| be7ee380bc | |||
| cff5de626b | |||
| 4d8b8f9210 | |||
| a16ef7e73c | |||
| c39dae06d4 | |||
| f906e70f6b | |||
| 5139119307 | |||
| 1b537f904a | |||
| 556b631c54 | |||
| 49df9ceaf3 | |||
| 92ec1ac27a | |||
| e74097afdf | |||
| 8ddc4f2292 | |||
| 7b51320346 | |||
| 9e39be0770 | |||
| 3e5e87930c | |||
| 15a5ba67f1 | |||
| 9e3b4dc90d | |||
| 48c42a9fba | |||
| 0b35bc1ede | |||
| 8e01bb40fe | |||
| 9d21772820 | |||
| b745839bdb | |||
| 59ad6e02ce | |||
| a3b33cbe28 | |||
| 7b8540281a | |||
| 0a6b78f883 | |||
| 56ee8f7d64 | |||
| 3cfcd32876 | |||
| 06dcb55a9d | |||
| ec6cafd7aa | |||
| 6e9858960d | |||
| 150a8276b9 | |||
| c6a90d4bb3 | |||
| c71fd7113c | |||
| 5fc104a992 | |||
| d1de3cfb94 | |||
| 44d36f2460 | |||
| 9088f151d9 | |||
| c692962650 | |||
| f0a60a9000 | |||
| 2f50f3fd4b | |||
| 24cd7bbc62 | |||
| d299e75e1b | |||
| f86b6658c9 | |||
| 0a56d65581 | |||
| dfc03bac9f | |||
| 81e1376e08 | |||
| f50c85d536 | |||
| 5830c69694 | |||
| 0173496a77 | |||
| 30c5b47699 | |||
| e3191d4e91 | |||
| a9b3539b90 | |||
| 5217017e69 | |||
| bd5df5cf1c | |||
| 456dbfe7d7 | |||
| 586f210d6e | |||
| 275a0f9ddd | |||
| cbf2ba6cec | |||
| 1bd621f819 | |||
| bb6a331490 | |||
| 3922ad876f | |||
| fdb53fdeb1 | |||
| 3fb5a7bff1 | |||
| 6157c67cfe | |||
| fbc745764a | |||
| 78f09801b5 | |||
| d0dd81cf84 | |||
| 65b832c46c | |||
| a90b60c36f | |||
| 94a07706ec | |||
| ab2eacb6c1 | |||
| aead192743 | |||
| c1e8584b97 | |||
| 8a2b208299 | |||
| 2b6882bd97 | |||
| aa51662d98 | |||
| 3068526797 | |||
| 298d8c2d88 | |||
| 294e01a8c1 | |||
| 3a5aa4587c | |||
| cf1778e696 | |||
| 54db4c176a | |||
| 5d3e8a31d0 | |||
| 885dff82e3 | |||
| 3c4aa24198 | |||
| 33b0814323 | |||
| 45ae511036 | |||
| 0fa063c640 | |||
| 40d35304ea | |||
| 89821d66bb | |||
| 09d84e900c | |||
| a8746bff30 | |||
| c4d8bf0ce9 | |||
| 9cca605bac | |||
| dbd23f91e5 | |||
| 9387cc088c | |||
| 11f7a89e25 | |||
| 654d522b31 | |||
| 31e6ef77a6 | |||
| e56c847210 | |||
| e00172199a | |||
| 04f47836d8 | |||
| faaca822e4 | |||
| dc0f053925 | |||
| 517726da3a | |||
| 1d6c03eddf | |||
| fdfccd1205 | |||
| b30e7ced0a | |||
| 11770439be | |||
| d89c5f7146 | |||
| 4a475bf1cd | |||
| 10be9cfbbf | |||
| c20e0ad90d | |||
| 22f64d60bb | |||
| 7b7d332239 | |||
| b1d189324a | |||
| 00fb468f2e | |||
| bbbb6e04cb | |||
| f5161d9add | |||
| 787251f00e | |||
| cfe21f0826 | |||
| 196f691865 | |||
| 7a5bb1cfac | |||
| b80d55b764 | |||
| dd71625f52 | |||
| 19936d23d1 |
6
.cursorrules
Normal file
6
.cursorrules
Normal file
@ -0,0 +1,6 @@
|
||||
# Cursor Rules for Dify Project
|
||||
|
||||
## Automated Test Generation
|
||||
|
||||
- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
|
||||
- When proposing or saving tests, re-read that document and follow every requirement.
|
||||
@ -1,4 +1,4 @@
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bullseye
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.12-bookworm
|
||||
|
||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
||||
@ -11,7 +11,7 @@
|
||||
"nodeGypDependencies": true,
|
||||
"version": "lts"
|
||||
},
|
||||
"ghcr.io/devcontainers-contrib/features/npm-package:1": {
|
||||
"ghcr.io/devcontainers-extra/features/npm-package:1": {
|
||||
"package": "typescript",
|
||||
"version": "latest"
|
||||
},
|
||||
|
||||
@ -6,11 +6,10 @@ cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
|
||||
|
||||
source /home/vscode/.bashrc
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ trim_trailing_whitespace = false
|
||||
|
||||
# Matches multiple files with brace expansion notation
|
||||
# Set default charset
|
||||
[*.{js,tsx}]
|
||||
[*.{js,jsx,ts,tsx,mjs}]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
|
||||
|
||||
226
.github/CODEOWNERS
vendored
Normal file
226
.github/CODEOWNERS
vendored
Normal file
@ -0,0 +1,226 @@
|
||||
# CODEOWNERS
|
||||
# This file defines code ownership for the Dify project.
|
||||
# Each line is a file pattern followed by one or more owners.
|
||||
# Owners can be @username, @org/team-name, or email addresses.
|
||||
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
|
||||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
api/ @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
||||
api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
||||
api/core/workflow/nodes/agent/ @Nov1c444
|
||||
api/core/workflow/nodes/iteration/ @Nov1c444
|
||||
api/core/workflow/nodes/loop/ @Nov1c444
|
||||
api/core/workflow/nodes/llm/ @Nov1c444
|
||||
|
||||
# Backend - RAG (Retrieval Augmented Generation)
|
||||
api/core/rag/ @JohnJyong
|
||||
api/services/rag_pipeline/ @JohnJyong
|
||||
api/services/dataset_service.py @JohnJyong
|
||||
api/services/knowledge_service.py @JohnJyong
|
||||
api/services/external_knowledge_service.py @JohnJyong
|
||||
api/services/hit_testing_service.py @JohnJyong
|
||||
api/services/metadata_service.py @JohnJyong
|
||||
api/services/vector_service.py @JohnJyong
|
||||
api/services/entities/knowledge_entities/ @JohnJyong
|
||||
api/services/entities/external_knowledge_entities/ @JohnJyong
|
||||
api/controllers/console/datasets/ @JohnJyong
|
||||
api/controllers/service_api/dataset/ @JohnJyong
|
||||
api/models/dataset.py @JohnJyong
|
||||
api/tasks/rag_pipeline/ @JohnJyong
|
||||
api/tasks/add_document_to_index_task.py @JohnJyong
|
||||
api/tasks/batch_clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_notion_document_task.py @JohnJyong
|
||||
api/tasks/document_indexing_task.py @JohnJyong
|
||||
api/tasks/document_indexing_sync_task.py @JohnJyong
|
||||
api/tasks/document_indexing_update_task.py @JohnJyong
|
||||
api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
||||
api/tasks/recover_document_indexing_task.py @JohnJyong
|
||||
api/tasks/remove_document_from_index_task.py @JohnJyong
|
||||
api/tasks/retry_document_indexing_task.py @JohnJyong
|
||||
api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
||||
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/delete_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segments_from_index_task.py @JohnJyong
|
||||
api/tasks/enable_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/enable_segments_to_index_task.py @JohnJyong
|
||||
api/tasks/clean_dataset_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||
|
||||
# Backend - Plugins
|
||||
api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||
|
||||
# Backend - Trigger/Schedule/Webhook
|
||||
api/controllers/trigger/ @Mairuis @Yeuoly
|
||||
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||
api/core/trigger/ @Mairuis @Yeuoly
|
||||
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||
api/services/trigger/ @Mairuis @Yeuoly
|
||||
api/models/trigger.py @Mairuis @Yeuoly
|
||||
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||
api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Async Workflow
|
||||
api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Billing
|
||||
api/services/billing_service.py @hj24 @zyssyz123
|
||||
api/controllers/console/billing/ @hj24 @zyssyz123
|
||||
|
||||
# Backend - Enterprise
|
||||
api/configs/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/feature_service.py @GarfieldDai @GareArc
|
||||
api/controllers/console/feature.py @GarfieldDai @GareArc
|
||||
api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||
|
||||
# Backend - Database Migrations
|
||||
api/migrations/ @snakevash @laipz8200
|
||||
|
||||
# Frontend
|
||||
web/ @iamjoel
|
||||
|
||||
# Frontend - App - Orchestration
|
||||
web/app/components/workflow/ @iamjoel @zxhlyh
|
||||
web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||
web/app/components/app/configuration/ @iamjoel @zxhlyh
|
||||
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Chat
|
||||
web/app/components/base/chat/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Completion
|
||||
web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - App - List and Creation
|
||||
web/app/components/apps/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Monitoring
|
||||
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Settings
|
||||
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - Hit Testing
|
||||
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - List and Creation
|
||||
web/app/components/datasets/list/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
|
||||
web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
||||
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - RAG - Documents List
|
||||
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
||||
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Segments List
|
||||
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Settings
|
||||
web/app/components/datasets/settings/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - Ecosystem - Plugins
|
||||
web/app/components/plugins/ @iamjoel @zhsama
|
||||
|
||||
# Frontend - Ecosystem - Tools
|
||||
web/app/components/tools/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Ecosystem - MarketPlace
|
||||
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Login and Registration
|
||||
web/app/signin/ @douxc @iamjoel
|
||||
web/app/signup/ @douxc @iamjoel
|
||||
web/app/reset-password/ @douxc @iamjoel
|
||||
web/app/install/ @douxc @iamjoel
|
||||
web/app/init/ @douxc @iamjoel
|
||||
web/app/forgot-password/ @douxc @iamjoel
|
||||
web/app/account/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Service Authentication
|
||||
web/service/base.ts @douxc @iamjoel
|
||||
|
||||
# Frontend - WebApp Authentication and Access Control
|
||||
web/app/(shareLayout)/components/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
||||
web/app/components/app/app-access-control/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Explore Page
|
||||
web/app/components/explore/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Personal Settings
|
||||
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
||||
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Analytics
|
||||
web/app/components/base/ga/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Base Components
|
||||
web/app/components/base/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Utils and Hooks
|
||||
web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
web/utils/time.ts @iamjoel @zxhlyh
|
||||
web/utils/format.ts @iamjoel @zxhlyh
|
||||
web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||
web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Billing and Education
|
||||
web/app/components/billing/ @iamjoel @zxhlyh
|
||||
web/app/education-apply/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Workspace
|
||||
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||
12
.github/copilot-instructions.md
vendored
Normal file
12
.github/copilot-instructions.md
vendored
Normal file
@ -0,0 +1,12 @@
|
||||
# Copilot Instructions
|
||||
|
||||
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
|
||||
|
||||
Key reminders:
|
||||
|
||||
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
|
||||
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
|
||||
- Target >95% line and branch coverage and 100% function/statement coverage.
|
||||
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
|
||||
|
||||
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.
|
||||
32
.github/workflows/api-tests.yml
vendored
32
.github/workflows/api-tests.yml
vendored
@ -39,25 +39,11 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Run pyrefly check
|
||||
run: |
|
||||
cd api
|
||||
uv add --dev pyrefly
|
||||
uv run pyrefly check || true
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
@ -76,7 +62,7 @@ jobs:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db
|
||||
db_postgres
|
||||
redis
|
||||
sandbox
|
||||
ssrf_proxy
|
||||
@ -93,3 +79,19 @@ jobs:
|
||||
|
||||
- name: Run TestContainers
|
||||
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
||||
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
|
||||
9
.github/workflows/autofix.yml
vendored
9
.github/workflows/autofix.yml
vendored
@ -2,6 +2,8 @@ name: autofix.ci
|
||||
on:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: ["main"]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
@ -26,10 +28,17 @@ jobs:
|
||||
# Format code
|
||||
uv run ruff format ..
|
||||
|
||||
- name: count migration progress
|
||||
run: |
|
||||
cd api
|
||||
./cnt_base.sh
|
||||
|
||||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
|
||||
# Convert Optional[T] to T | None (ignoring quoted types)
|
||||
cat > /tmp/optional-rule.yml << 'EOF'
|
||||
id: convert-optional-to-union
|
||||
|
||||
3
.github/workflows/build-push.yml
vendored
3
.github/workflows/build-push.yml
vendored
@ -4,8 +4,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
- "deploy/dev"
|
||||
- "deploy/enterprise"
|
||||
- "deploy/**"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "hotfix/**"
|
||||
|
||||
61
.github/workflows/db-migration-test.yml
vendored
61
.github/workflows/db-migration-test.yml
vendored
@ -8,7 +8,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
db-migration-test:
|
||||
db-migration-test-postgres:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@ -45,7 +45,7 @@ jobs:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db
|
||||
db_postgres
|
||||
redis
|
||||
|
||||
- name: Prepare configs
|
||||
@ -57,3 +57,60 @@ jobs:
|
||||
env:
|
||||
DEBUG: true
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
||||
db-migration-test-mysql:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api
|
||||
- name: Ensure Offline migration are supported
|
||||
run: |
|
||||
# upgrade
|
||||
uv run --directory api flask db upgrade 'base:head' --sql
|
||||
# downgrade
|
||||
uv run --directory api flask db downgrade 'head:base' --sql
|
||||
|
||||
- name: Prepare middleware env for MySQL
|
||||
run: |
|
||||
cd docker
|
||||
cp middleware.env.example middleware.env
|
||||
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
|
||||
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db_mysql
|
||||
redis
|
||||
|
||||
- name: Prepare configs for MySQL
|
||||
run: |
|
||||
cd api
|
||||
cp .env.example .env
|
||||
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' .env
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
|
||||
|
||||
- name: Run DB Migration
|
||||
env:
|
||||
DEBUG: true
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
||||
2
.github/workflows/deploy-dev.yml
vendored
2
.github/workflows/deploy-dev.yml
vendored
@ -18,7 +18,7 @@ jobs:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
name: Deploy RAG Dev
|
||||
name: Deploy Trigger Dev
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@ -7,7 +7,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/rag-dev"
|
||||
- "deploy/trigger-dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@ -16,12 +16,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
host: ${{ secrets.TRIGGER_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
3
.github/workflows/expose_service_ports.sh
vendored
3
.github/workflows/expose_service_ports.sh
vendored
@ -1,6 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
||||
@ -13,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
|
||||
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
|
||||
|
||||
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||
|
||||
5
.github/workflows/style.yml
vendored
5
.github/workflows/style.yml
vendored
@ -103,6 +103,11 @@ jobs:
|
||||
run: |
|
||||
pnpm run lint
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run type-check
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@ -20,22 +20,22 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
run: |
|
||||
recent_commit_sha=$(git rev-parse HEAD)
|
||||
second_recent_commit_sha=$(git rev-parse HEAD~1)
|
||||
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
|
||||
git fetch origin "${{ github.event.before }}" || true
|
||||
git fetch origin "${{ github.sha }}" || true
|
||||
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
|
||||
echo "Changed files: $changed_files"
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "FILES_CHANGED=true" >> $GITHUB_ENV
|
||||
file_args=""
|
||||
for file in $changed_files; do
|
||||
filename=$(basename "$file" .ts)
|
||||
file_args="$file_args --file=$filename"
|
||||
file_args="$file_args --file $filename"
|
||||
done
|
||||
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
|
||||
echo "File arguments: $file_args"
|
||||
@ -77,12 +77,15 @@ jobs:
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: Update i18n files and type definitions based on en-US changes
|
||||
title: 'chore: translate i18n files and update type definitions'
|
||||
commit-message: 'chore(i18n): update translations based on en-US changes'
|
||||
title: 'chore(i18n): translate i18n files and update type definitions'
|
||||
body: |
|
||||
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
|
||||
|
||||
|
||||
**Triggered by:** ${{ github.sha }}
|
||||
|
||||
**Changes included:**
|
||||
- Updated translation files for all locales
|
||||
- Regenerated TypeScript type definitions for type safety
|
||||
branch: chore/automated-i18n-updates
|
||||
branch: chore/automated-i18n-updates-${{ github.sha }}
|
||||
delete-branch: true
|
||||
|
||||
18
.github/workflows/vdb-tests.yml
vendored
18
.github/workflows/vdb-tests.yml
vendored
@ -51,13 +51,13 @@ jobs:
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Vector Store (TiDB)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: docker/tidb/docker-compose.yaml
|
||||
services: |
|
||||
tidb
|
||||
tiflash
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
# compose-file: docker/tidb/docker-compose.yaml
|
||||
# services: |
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
@ -83,8 +83,8 @@ jobs:
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
- name: Check VDB Ready (TiDB)
|
||||
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
11
.gitignore
vendored
11
.gitignore
vendored
@ -6,6 +6,9 @@ __pycache__/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# *db files
|
||||
*.db
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
@ -97,6 +100,7 @@ __pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat-schedule.db
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
@ -182,6 +186,8 @@ docker/volumes/couchbase/*
|
||||
docker/volumes/oceanbase/*
|
||||
docker/volumes/plugin_daemon/*
|
||||
docker/volumes/matrixone/*
|
||||
docker/volumes/mysql/*
|
||||
docker/volumes/seekdb/*
|
||||
!docker/volumes/oceanbase/init.d
|
||||
|
||||
docker/nginx/conf.d/default.conf
|
||||
@ -234,4 +240,7 @@ scripts/stress-test/reports/
|
||||
|
||||
# mcp
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
.serena/
|
||||
|
||||
# settings
|
||||
*.local.json
|
||||
|
||||
9
.vscode/launch.json.template
vendored
9
.vscode/launch.json.template
vendored
@ -8,8 +8,7 @@
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_ENV": "development",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
"FLASK_ENV": "development"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
@ -28,9 +27,7 @@
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"env": {
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"app.celery",
|
||||
@ -40,7 +37,7 @@
|
||||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
|
||||
5
.windsurf/rules/testing.md
Normal file
5
.windsurf/rules/testing.md
Normal file
@ -0,0 +1,5 @@
|
||||
# Windsurf Testing Rules
|
||||
|
||||
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
|
||||
- Honor every requirement in that document when generating or accepting tests.
|
||||
- When proposing or saving tests, re-read that document and follow every requirement.
|
||||
@ -14,7 +14,7 @@ The codebase is split into:
|
||||
|
||||
- Run backend CLI commands through `uv run --project api <command>`.
|
||||
|
||||
- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review.
|
||||
- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
|
||||
|
||||
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
|
||||
|
||||
|
||||
@ -77,6 +77,8 @@ How we prioritize:
|
||||
|
||||
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
|
||||
|
||||
**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
|
||||
|
||||
#### Backend
|
||||
|
||||
For setting up the backend service, kindly refer to our detailed [instructions](https://github.com/langgenius/dify/blob/main/api/README.md) in the `api/README.md` file. This document contains step-by-step guidance to help you get the backend up and running smoothly.
|
||||
|
||||
9
Makefile
9
Makefile
@ -26,7 +26,6 @@ prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cd web && pnpm build
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
@ -71,6 +70,11 @@ type-check:
|
||||
@uv run --directory api --dev basedpyright
|
||||
@echo "✅ Type check complete"
|
||||
|
||||
test:
|
||||
@echo "🧪 Running backend unit tests..."
|
||||
@uv run --project api --dev dev/pytest/pytest_unit_tests.sh
|
||||
@echo "✅ Tests complete"
|
||||
|
||||
# Build Docker images
|
||||
build-web:
|
||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||
@ -120,6 +124,7 @@ help:
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format and fix code with ruff"
|
||||
@echo " make type-check - Run type checking with basedpyright"
|
||||
@echo " make test - Run backend unit tests"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
@echo " make build-web - Build web Docker image"
|
||||
@ -129,4 +134,4 @@ help:
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test
|
||||
|
||||
50
README.md
50
README.md
@ -36,22 +36,28 @@
|
||||
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
|
||||
<a href="https://github.com/langgenius/dify/discussions/" target="_blank">
|
||||
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
|
||||
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
|
||||
<img alt="LFX Health Score" src="https://insights.linuxfoundation.org/api/badge/health-score?project=langgenius-dify"></a>
|
||||
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
|
||||
<img alt="LFX Contributors" src="https://insights.linuxfoundation.org/api/badge/contributors?project=langgenius-dify"></a>
|
||||
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
|
||||
<img alt="LFX Active Contributors" src="https://insights.linuxfoundation.org/api/badge/active-contributors?project=langgenius-dify"></a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README/README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./README/README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README/README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README/README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README/README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README/README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./README/README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./README/README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./README/README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./README/README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./README/README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./README/README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
<a href="./docs/zh-TW/README.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./docs/zh-CN/README.md"><img alt="简体中文文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./docs/ja-JP/README.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./docs/es-ES/README.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./docs/fr-FR/README.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./docs/tlh/README.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./docs/ko-KR/README.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./docs/ar-SA/README.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./docs/tr-TR/README.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./docs/vi-VN/README.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./docs/de-DE/README.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
<a href="./docs/bn-BD/README.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
|
||||
@ -63,7 +69,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
|
||||
> - CPU >= 2 Core
|
||||
> - RAM >= 4 GiB
|
||||
|
||||
</br>
|
||||
<br/>
|
||||
|
||||
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
||||
|
||||
@ -109,15 +115,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
|
||||
|
||||
## Using Dify
|
||||
|
||||
- **Cloud </br>**
|
||||
- **Cloud <br/>**
|
||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||
|
||||
- **Self-hosting Dify Community Edition</br>**
|
||||
- **Self-hosting Dify Community Edition<br/>**
|
||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||
|
||||
- **Dify for enterprise / organizations</br>**
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br>
|
||||
- **Dify for enterprise / organizations<br/>**
|
||||
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
|
||||
|
||||
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
||||
|
||||
@ -129,8 +135,18 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
## Advanced Setup
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
|
||||
|
||||
- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
|
||||
|
||||
### Deployment with Kubernetes
|
||||
|
||||
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||
|
||||
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
||||
|
||||
@ -27,6 +27,9 @@ FILES_URL=http://localhost:5001
|
||||
# Example: INTERNAL_FILES_URL=http://api:5001
|
||||
INTERNAL_FILES_URL=http://127.0.0.1:5001
|
||||
|
||||
# TRIGGER URL
|
||||
TRIGGER_URL=http://localhost:5001
|
||||
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
@ -69,12 +72,15 @@ REDIS_CLUSTERS_PASSWORD=
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
|
||||
CELERY_BACKEND=redis
|
||||
# PostgreSQL database configuration
|
||||
|
||||
# Database configuration
|
||||
DB_TYPE=postgresql
|
||||
DB_USERNAME=postgres
|
||||
DB_PASSWORD=difyai123456
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_DATABASE=dify
|
||||
|
||||
SQLALCHEMY_POOL_PRE_PING=true
|
||||
SQLALCHEMY_POOL_TIMEOUT=30
|
||||
|
||||
@ -156,9 +162,11 @@ SUPABASE_URL=your-server-url
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). Leading dots are optional.
|
||||
COOKIE_DOMAIN=
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
@ -168,6 +176,18 @@ WEAVIATE_ENDPOINT=http://localhost:8080
|
||||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
WEAVIATE_TOKENIZATION=word
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
OCEANBASE_VECTOR_USER=root@test
|
||||
OCEANBASE_VECTOR_PASSWORD=difyai123456
|
||||
OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||
OCEANBASE_FULLTEXT_PARSER=ik
|
||||
SEEKDB_MEMORY_LIMIT=2G
|
||||
|
||||
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
|
||||
QDRANT_URL=http://localhost:6333
|
||||
@ -334,14 +354,14 @@ LINDORM_PASSWORD=admin
|
||||
LINDORM_USING_UGC=True
|
||||
LINDORM_QUERY_TIMEOUT=1
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
OCEANBASE_VECTOR_USER=root@test
|
||||
OCEANBASE_VECTOR_PASSWORD=difyai123456
|
||||
OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||
# AlibabaCloud MySQL Vector configuration
|
||||
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
|
||||
ALIBABACLOUD_MYSQL_PORT=3306
|
||||
ALIBABACLOUD_MYSQL_USER=root
|
||||
ALIBABACLOUD_MYSQL_PASSWORD=root
|
||||
ALIBABACLOUD_MYSQL_DATABASE=dify
|
||||
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
|
||||
ALIBABACLOUD_MYSQL_HNSW_M=6
|
||||
|
||||
# openGauss configuration
|
||||
OPENGAUSS_HOST=127.0.0.1
|
||||
@ -359,6 +379,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||
|
||||
# Comma-separated list of file extensions blocked from upload for security reasons.
|
||||
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
|
||||
# Empty by default to allow all file types.
|
||||
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
|
||||
UPLOAD_FILE_EXTENSION_BLACKLIST=
|
||||
|
||||
# Model configuration
|
||||
MULTIMODAL_SEND_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=512
|
||||
@ -425,10 +451,13 @@ CODE_EXECUTION_SSL_VERIFY=True
|
||||
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
|
||||
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
||||
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
|
||||
CODE_EXECUTION_CONNECT_TIMEOUT=10
|
||||
CODE_EXECUTION_READ_TIMEOUT=60
|
||||
CODE_EXECUTION_WRITE_TIMEOUT=10
|
||||
CODE_MAX_NUMBER=9223372036854775807
|
||||
CODE_MIN_NUMBER=-9223372036854775808
|
||||
CODE_MAX_STRING_LENGTH=80000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=80000
|
||||
CODE_MAX_STRING_LENGTH=400000
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH=400000
|
||||
CODE_MAX_STRING_ARRAY_LENGTH=30
|
||||
CODE_MAX_OBJECT_ARRAY_LENGTH=30
|
||||
CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
||||
@ -445,6 +474,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY=True
|
||||
|
||||
# Webhook request configuration
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
|
||||
|
||||
# Respect X-* headers to redirect clients
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
@ -500,7 +532,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
|
||||
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
|
||||
# Workflow log cleanup configuration
|
||||
# Enable automatic cleanup of workflow run logs to manage database size
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=true
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=false
|
||||
# Number of days to retain workflow run logs (default: 30 days)
|
||||
WORKFLOW_LOG_RETENTION_DAYS=30
|
||||
# Batch size for workflow log cleanup operations (default: 100)
|
||||
@ -508,6 +540,7 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
|
||||
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
APP_DEFAULT_ACTIVE_REQUESTS=0
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
# Celery beat configuration
|
||||
@ -522,6 +555,12 @@ ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
|
||||
# Interval time in minutes for polling scheduled workflows(default: 1 min)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
|
||||
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
|
||||
|
||||
# Position configuration
|
||||
POSITION_TOOL_PINS=
|
||||
@ -593,3 +632,9 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||
|
||||
@ -16,6 +16,7 @@ layers =
|
||||
graph
|
||||
nodes
|
||||
node_events
|
||||
runtime
|
||||
entities
|
||||
containers =
|
||||
core.workflow
|
||||
|
||||
@ -81,7 +81,6 @@ ignore = [
|
||||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
|
||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@ -54,7 +54,7 @@
|
||||
"--loglevel",
|
||||
"DEBUG",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace,app_deletion"
|
||||
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
62
api/AGENTS.md
Normal file
62
api/AGENTS.md
Normal file
@ -0,0 +1,62 @@
|
||||
# Agent Skill Index
|
||||
|
||||
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Platform Foundations
|
||||
|
||||
- **[Infrastructure Overview](agent_skills/infra.md)**\
|
||||
When to read this:
|
||||
|
||||
- You need to understand where a feature belongs in the architecture.
|
||||
- You’re wiring storage, Redis, vector stores, or OTEL.
|
||||
- You’re about to add CLI commands or async jobs.\
|
||||
What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
|
||||
|
||||
- **[Coding Style](agent_skills/coding_style.md)**\
|
||||
When to read this:
|
||||
|
||||
- You’re writing or reviewing backend code and need the authoritative checklist.
|
||||
- You’re unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
|
||||
- You want the exact lint/type/test commands used in PRs.\
|
||||
Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Plugin & Extension Development
|
||||
|
||||
- **[Plugin Systems](agent_skills/plugin.md)**\
|
||||
When to read this:
|
||||
|
||||
- You’re building or debugging a marketplace plugin.
|
||||
- You need to know how manifests, providers, daemons, and migrations fit together.\
|
||||
What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
|
||||
|
||||
- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
|
||||
When to read this:
|
||||
|
||||
- You must integrate OAuth for a plugin or datasource.
|
||||
- You’re handling credential encryption or refresh flows.\
|
||||
Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Workflow Entry & Execution
|
||||
|
||||
- **[Trigger Concepts](agent_skills/trigger.md)**\
|
||||
When to read this:
|
||||
- You’re debugging why a workflow didn’t start.
|
||||
- You’re adding a new trigger type or hook.
|
||||
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
|
||||
Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Additional Notes for Agents
|
||||
|
||||
- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
|
||||
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
|
||||
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
|
||||
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
|
||||
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
|
||||
@ -15,7 +15,11 @@ FROM base AS packages
|
||||
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies
|
||||
COPY pyproject.toml uv.lock ./
|
||||
@ -44,14 +48,22 @@ ENV PYTHONIOENCODING=utf-8
|
||||
|
||||
WORKDIR /app/api
|
||||
|
||||
# Create non-root user
|
||||
ARG dify_uid=1001
|
||||
RUN groupadd -r -g ${dify_uid} dify && \
|
||||
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
|
||||
chown -R dify:dify /app
|
||||
|
||||
RUN \
|
||||
apt-get update \
|
||||
# Install dependencies
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
|
||||
curl nodejs \
|
||||
# for gmpy2 \
|
||||
libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# For Security
|
||||
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
expat libldap-2.5-0=2.5.13+dfsg-5 perl libsqlite3-0=3.40.1-2+deb12u2 zlib1g=1:1.2.13.dfsg-1 \
|
||||
# install fonts to support the use of tools like pypdfium2
|
||||
fonts-noto-cjk \
|
||||
# install a package to improve the accuracy of guessing mime type and file extension
|
||||
@ -63,24 +75,29 @@ RUN \
|
||||
|
||||
# Copy Python environment and packages
|
||||
ENV VIRTUAL_ENV=/app/api/.venv
|
||||
COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
||||
|
||||
# Download nltk data
|
||||
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
|
||||
RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
|
||||
&& chmod -R 755 /usr/local/share/nltk_data
|
||||
|
||||
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
|
||||
|
||||
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
|
||||
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" \
|
||||
&& chown -R dify:dify ${TIKTOKEN_CACHE_DIR}
|
||||
|
||||
# Copy source code
|
||||
COPY . /app/api/
|
||||
COPY --chown=dify:dify . /app/api/
|
||||
|
||||
# Prepare entrypoint script
|
||||
COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh /entrypoint.sh
|
||||
|
||||
# Copy entrypoint
|
||||
COPY docker/entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
ARG COMMIT_SHA
|
||||
ENV COMMIT_SHA=${COMMIT_SHA}
|
||||
ENV NLTK_DATA=/usr/local/share/nltk_data
|
||||
|
||||
USER dify
|
||||
|
||||
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
|
||||
|
||||
@ -15,8 +15,8 @@
|
||||
```bash
|
||||
cd ../docker
|
||||
cp middleware.env.example middleware.env
|
||||
# change the profile to other vector database if you are not using weaviate
|
||||
docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d
|
||||
# change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
|
||||
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
cd ../api
|
||||
```
|
||||
|
||||
@ -26,6 +26,10 @@
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
|
||||
|
||||
1. Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
bash for Linux
|
||||
@ -80,7 +84,7 @@
|
||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
|
||||
```
|
||||
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
||||
115
api/agent_skills/coding_style.md
Normal file
115
api/agent_skills/coding_style.md
Normal file
@ -0,0 +1,115 @@
|
||||
## Linter
|
||||
|
||||
- Always follow `.ruff.toml`.
|
||||
- Run `uv run ruff check --fix --unsafe-fixes`.
|
||||
- Keep each line under 100 characters (including spaces).
|
||||
|
||||
## Code Style
|
||||
|
||||
- `snake_case` for variables and functions.
|
||||
- `PascalCase` for classes.
|
||||
- `UPPER_CASE` for constants.
|
||||
|
||||
## Rules
|
||||
|
||||
- Use Pydantic v2 standard.
|
||||
- Use `uv` for package management.
|
||||
- Do not override dunder methods like `__init__`, `__iadd__`, etc.
|
||||
- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
|
||||
- Prefer simple functions over classes for lightweight helpers.
|
||||
- Keep files below 800 lines; split when necessary.
|
||||
- Keep code readable—no clever hacks.
|
||||
- Never use `print`; log with `logger = logging.getLogger(__name__)`.
|
||||
|
||||
## Guiding Principles
|
||||
|
||||
- Mirror the project’s layered architecture: controller → service → core/domain.
|
||||
- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
|
||||
- Optimise for observability: deterministic control flow, clear logging, actionable errors.
|
||||
|
||||
## SQLAlchemy Patterns
|
||||
|
||||
- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
|
||||
|
||||
- Open sessions with context managers:
|
||||
|
||||
```python
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.tenant_id == tenant_id,
|
||||
)
|
||||
workflow = session.execute(stmt).scalar_one_or_none()
|
||||
```
|
||||
|
||||
- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
|
||||
|
||||
- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
|
||||
|
||||
- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
|
||||
|
||||
## Storage & External IO
|
||||
|
||||
- Access storage via `extensions.ext_storage.storage`.
|
||||
- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
|
||||
- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
|
||||
|
||||
## Pydantic Usage
|
||||
|
||||
- Define DTOs with Pydantic v2 models and forbid extras by default.
|
||||
|
||||
- Use `@field_validator` / `@model_validator` for domain rules.
|
||||
|
||||
- Example:
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
|
||||
|
||||
class TriggerConfig(BaseModel):
|
||||
endpoint: HttpUrl
|
||||
secret: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("secret")
|
||||
def ensure_secret_prefix(cls, value: str) -> str:
|
||||
if not value.startswith("dify_"):
|
||||
raise ValueError("secret must start with dify_")
|
||||
return value
|
||||
```
|
||||
|
||||
## Generics & Protocols
|
||||
|
||||
- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
|
||||
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
|
||||
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
|
||||
|
||||
## Error Handling & Logging
|
||||
|
||||
- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
|
||||
- Declare `logger = logging.getLogger(__name__)` at module top.
|
||||
- Include tenant/app/workflow identifiers in log context.
|
||||
- Log retryable events at `warning`, terminal failures at `error`.
|
||||
|
||||
## Tooling & Checks
|
||||
|
||||
- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
|
||||
- Type checks: `uv run --directory api --dev basedpyright`.
|
||||
- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
|
||||
- Run all of the above before submitting your work.
|
||||
|
||||
## Controllers & Services
|
||||
|
||||
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
|
||||
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
|
||||
- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
|
||||
- Document non-obvious behaviour with concise comments.
|
||||
|
||||
## Miscellaneous
|
||||
|
||||
- Use `configs.dify_config` for configuration—never read environment variables directly.
|
||||
- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
|
||||
- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
|
||||
- Keep experimental scripts under `dev/`; do not ship them in production builds.
|
||||
96
api/agent_skills/infra.md
Normal file
96
api/agent_skills/infra.md
Normal file
@ -0,0 +1,96 @@
|
||||
## Configuration
|
||||
|
||||
- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
|
||||
- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
|
||||
- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
|
||||
- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
|
||||
|
||||
## Dependencies
|
||||
|
||||
- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
|
||||
- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
|
||||
- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
|
||||
|
||||
## Storage & Files
|
||||
|
||||
- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
|
||||
- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
|
||||
- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
|
||||
- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
|
||||
|
||||
## Redis & Shared State
|
||||
|
||||
- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
|
||||
- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
|
||||
|
||||
## Models
|
||||
|
||||
- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
|
||||
- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
|
||||
- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
|
||||
- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
|
||||
|
||||
## Vector Stores
|
||||
|
||||
- Vector client implementations live in `core/rag/datasource/vdb/<provider>`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
|
||||
- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
|
||||
- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
|
||||
- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
|
||||
|
||||
## Observability & OTEL
|
||||
|
||||
- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
|
||||
- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
|
||||
- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
|
||||
- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
|
||||
|
||||
## Ops Integrations
|
||||
|
||||
- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
|
||||
- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
|
||||
- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
|
||||
|
||||
## Controllers, Services, Core
|
||||
|
||||
- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
|
||||
- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
|
||||
- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
|
||||
|
||||
## Plugins, Tools, Providers
|
||||
|
||||
- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
|
||||
- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
|
||||
- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
|
||||
- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
|
||||
- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
|
||||
- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
|
||||
|
||||
## Async Workloads
|
||||
|
||||
see `agent_skills/trigger.md` for more detailed documentation.
|
||||
|
||||
- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
|
||||
- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
|
||||
- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
|
||||
- Generate migrations with `uv run --project api flask db revision --autogenerate -m "<summary>"`, then review the diff; never hand-edit the database outside Alembic.
|
||||
- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
|
||||
- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
|
||||
|
||||
## CLI Commands
|
||||
|
||||
- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask <command>`.
|
||||
- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
|
||||
- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
|
||||
- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
|
||||
- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
|
||||
|
||||
## When You Add Features
|
||||
|
||||
- Check for an existing helper or service before writing a new util.
|
||||
- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
|
||||
- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
|
||||
- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.
|
||||
1
api/agent_skills/plugin.md
Normal file
1
api/agent_skills/plugin.md
Normal file
@ -0,0 +1 @@
|
||||
// TBD
|
||||
1
api/agent_skills/plugin_oauth.md
Normal file
1
api/agent_skills/plugin_oauth.md
Normal file
@ -0,0 +1 @@
|
||||
// TBD
|
||||
53
api/agent_skills/trigger.md
Normal file
53
api/agent_skills/trigger.md
Normal file
@ -0,0 +1,53 @@
|
||||
## Overview
|
||||
|
||||
Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node.
|
||||
|
||||
## Trigger nodes
|
||||
|
||||
- `UserInput`
|
||||
- `Trigger Webhook`
|
||||
- `Trigger Schedule`
|
||||
- `Trigger Plugin`
|
||||
|
||||
### UserInput
|
||||
|
||||
Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app`
|
||||
|
||||
1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool.
|
||||
1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node.
|
||||
1. For its detailed implementation, please refer to `core/workflow/nodes/start`
|
||||
|
||||
### Trigger Webhook
|
||||
|
||||
Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`.
|
||||
|
||||
Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution.
|
||||
|
||||
### Trigger Schedule
|
||||
|
||||
`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help.
|
||||
|
||||
To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published.
|
||||
|
||||
### Trigger Plugin
|
||||
|
||||
`Trigger Plugin` node allows user define there own distributed trigger plugin, whenever a request was received, Dify forwards it to the plugin and wait for parsed variables from it.
|
||||
|
||||
1. Requests were saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`
|
||||
1. Plugins accept those requests and parse variables from it, see `core/plugin/impl/trigger.py` for details.
|
||||
|
||||
A `subscription` concept was out here by Dify, it means an endpoint address from Dify was bound to thirdparty webhook service like `Github` `Slack` `Linear` `GoogleDrive` `Gmail` etc. Once a subscription was created, Dify continually receives requests from the platforms and handle them one by one.
|
||||
|
||||
## Worker Pool / Async Task
|
||||
|
||||
All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
|
||||
|
||||
The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`.
|
||||
|
||||
## Debug Strategy
|
||||
|
||||
Dify divided users into 2 groups: builders / end users.
|
||||
|
||||
Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
|
||||
|
||||
A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type.
|
||||
21
api/app.py
21
api/app.py
@ -1,7 +1,7 @@
|
||||
import sys
|
||||
|
||||
|
||||
def is_db_command():
|
||||
def is_db_command() -> bool:
|
||||
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
|
||||
return True
|
||||
return False
|
||||
@ -13,23 +13,12 @@ if is_db_command():
|
||||
|
||||
app = create_migrations_app()
|
||||
else:
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
# from gevent import monkey
|
||||
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||
#
|
||||
# # gevent
|
||||
# monkey.patch_all()
|
||||
# See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
|
||||
#
|
||||
# from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
#
|
||||
# # grpc gevent
|
||||
# grpc_gevent.init_gevent()
|
||||
|
||||
# import psycogreen.gevent # type: ignore
|
||||
#
|
||||
# psycogreen.gevent.patch_psycopg()
|
||||
# For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
"""
|
||||
dify_app = DifyApp(__name__)
|
||||
dify_app.config.from_mapping(dify_config.model_dump())
|
||||
dify_app.config["RESTX_INCLUDE_ALL_MODELS"] = True
|
||||
|
||||
# add before request hook
|
||||
@dify_app.before_request
|
||||
@ -50,6 +51,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_forward_refs,
|
||||
ext_hosting_provider,
|
||||
ext_import_modules,
|
||||
ext_logging,
|
||||
@ -74,6 +76,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_warnings,
|
||||
ext_import_modules,
|
||||
ext_orjson,
|
||||
ext_forward_refs,
|
||||
ext_set_secretkey,
|
||||
ext_compress,
|
||||
ext_code_based_extension,
|
||||
|
||||
7
api/cnt_base.sh
Executable file
7
api/cnt_base.sh
Executable file
@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
set -euxo pipefail
|
||||
|
||||
for pattern in "Base" "TypeBase"; do
|
||||
printf "%s " "$pattern"
|
||||
grep "($pattern):" -r --include='*.py' --exclude-dir=".venv" --exclude-dir="tests" . | wc -l
|
||||
done
|
||||
@ -15,12 +15,12 @@ from sqlalchemy.orm import sessionmaker
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.models.document import Document
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
@ -321,6 +321,8 @@ def migrate_knowledge_vector_database():
|
||||
)
|
||||
|
||||
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
||||
if not datasets.items:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
|
||||
@ -1227,6 +1229,55 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_system_trigger_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup system trigger oauth client
|
||||
"""
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerOAuthSystemClient
|
||||
|
||||
provider_id = TriggerProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
oauth_client = TriggerOAuthSystemClient(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
encrypted_oauth_params=oauth_client_params,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
"""
|
||||
Find draft variables that reference non-existent apps.
|
||||
@ -1420,7 +1471,10 @@ def setup_datasource_oauth_client(provider, client_params):
|
||||
|
||||
|
||||
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
||||
def transform_datasource_credentials():
|
||||
@click.option(
|
||||
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
|
||||
)
|
||||
def transform_datasource_credentials(environment: str):
|
||||
"""
|
||||
Transform datasource credentials
|
||||
"""
|
||||
@ -1431,9 +1485,14 @@ def transform_datasource_credentials():
|
||||
notion_plugin_id = "langgenius/notion_datasource"
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
if environment == "online":
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
else:
|
||||
notion_plugin_unique_identifier = None
|
||||
firecrawl_plugin_unique_identifier = None
|
||||
jina_plugin_unique_identifier = None
|
||||
oauth_credential_type = CredentialType.OAUTH2
|
||||
api_key_credential_type = CredentialType.API_KEY
|
||||
|
||||
@ -1521,6 +1580,14 @@ def transform_datasource_credentials():
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not firecrawl_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
@ -1576,6 +1643,14 @@ def transform_datasource_credentials():
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not jina_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
@ -1583,7 +1658,7 @@ def transform_datasource_credentials():
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
provider="jinareader",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
|
||||
@ -73,14 +73,14 @@ class AppExecutionConfig(BaseSettings):
|
||||
description="Maximum allowed execution time for the application in seconds",
|
||||
default=1200,
|
||||
)
|
||||
APP_DEFAULT_ACTIVE_REQUESTS: NonNegativeInt = Field(
|
||||
description="Default number of concurrent active requests per app (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
|
||||
description="Maximum number of concurrent active requests per app (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
APP_DAILY_RATE_LIMIT: NonNegativeInt = Field(
|
||||
description="Maximum number of requests per app per day",
|
||||
default=5000,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutionSandboxConfig(BaseSettings):
|
||||
@ -150,7 +150,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
|
||||
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
|
||||
description="Maximum allowed length for strings in code execution",
|
||||
default=80000,
|
||||
default=400_000,
|
||||
)
|
||||
|
||||
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
|
||||
@ -174,6 +174,33 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TriggerConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for trigger
|
||||
"""
|
||||
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for webhook request bodies in bytes",
|
||||
default=10485760,
|
||||
)
|
||||
|
||||
|
||||
class AsyncWorkflowConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for async workflow
|
||||
"""
|
||||
|
||||
ASYNC_WORKFLOW_SCHEDULER_GRANULARITY: int = Field(
|
||||
description="Granularity for async workflow scheduler, "
|
||||
"sometime, few users could block the queue due to some time-consuming tasks, "
|
||||
"to avoid this, workflow can be suspended if needed, to achieve"
|
||||
"this, a time-based checker is required, every granularity seconds, "
|
||||
"the checker will check the workflow queue and suspend the workflow",
|
||||
default=120,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
|
||||
class PluginConfig(BaseSettings):
|
||||
"""
|
||||
Plugin configs
|
||||
@ -189,6 +216,11 @@ class PluginConfig(BaseSettings):
|
||||
default="plugin-api-key",
|
||||
)
|
||||
|
||||
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||
default=300.0,
|
||||
)
|
||||
|
||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||
|
||||
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
||||
@ -258,6 +290,8 @@ class EndpointConfig(BaseSettings):
|
||||
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
|
||||
)
|
||||
|
||||
TRIGGER_URL: str = Field(description="Template url for triggers", default="http://localhost:5001")
|
||||
|
||||
|
||||
class FileAccessConfig(BaseSettings):
|
||||
"""
|
||||
@ -326,12 +360,42 @@ class FileUploadConfig(BaseSettings):
|
||||
default=10,
|
||||
)
|
||||
|
||||
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
||||
description=(
|
||||
"Comma-separated list of file extensions that are blocked from upload. "
|
||||
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
|
||||
"Empty by default to allow all file types."
|
||||
),
|
||||
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
|
||||
"""
|
||||
Parse and return the blacklist as a set of lowercase extensions.
|
||||
Returns an empty set if no blacklist is configured.
|
||||
"""
|
||||
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
|
||||
return set()
|
||||
return {
|
||||
ext.strip().lower().strip(".")
|
||||
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
|
||||
if ext.strip()
|
||||
}
|
||||
|
||||
|
||||
class HttpConfig(BaseSettings):
|
||||
"""
|
||||
HTTP-related configurations for the application
|
||||
"""
|
||||
|
||||
COOKIE_DOMAIN: str = Field(
|
||||
description="Explicit cookie domain for console/service cookies when sharing across subdomains",
|
||||
default="",
|
||||
)
|
||||
|
||||
API_COMPRESSION_ENABLED: bool = Field(
|
||||
description="Enable or disable gzip compression for HTTP responses",
|
||||
default=False,
|
||||
@ -362,11 +426,11 @@ class HttpConfig(BaseSettings):
|
||||
)
|
||||
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
|
||||
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
|
||||
)
|
||||
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
|
||||
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
|
||||
)
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
||||
@ -543,7 +607,7 @@ class UpdateConfig(BaseSettings):
|
||||
|
||||
class WorkflowVariableTruncationConfig(BaseSettings):
|
||||
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
|
||||
# 100KB
|
||||
# 1000 KiB
|
||||
1024_000,
|
||||
description="Maximum size for variable to trigger final truncation.",
|
||||
)
|
||||
@ -582,6 +646,11 @@ class WorkflowConfig(BaseSettings):
|
||||
default=200 * 1024,
|
||||
)
|
||||
|
||||
TEMPLATE_TRANSFORM_MAX_LENGTH: PositiveInt = Field(
|
||||
description="Maximum number of characters allowed in Template Transform node output",
|
||||
default=400_000,
|
||||
)
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
@ -766,7 +835,7 @@ class MailConfig(BaseSettings):
|
||||
|
||||
MAIL_TEMPLATING_TIMEOUT: int = Field(
|
||||
description="""
|
||||
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
|
||||
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
|
||||
Only available in sandbox mode.""",
|
||||
default=3,
|
||||
)
|
||||
@ -905,6 +974,11 @@ class DataSetConfig(BaseSettings):
|
||||
default=True,
|
||||
)
|
||||
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
|
||||
description="Maximum number of segments for dataset segments API (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
@ -980,6 +1054,44 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable check upgradable plugin task",
|
||||
default=True,
|
||||
)
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
|
||||
description="Enable workflow schedule poller task",
|
||||
default=True,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
|
||||
description="Workflow schedule poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
|
||||
description="Maximum number of schedules to process in each poll batch",
|
||||
default=100,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
|
||||
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
# Trigger provider refresh (simple version)
|
||||
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
|
||||
description="Enable trigger provider refresh poller",
|
||||
default=True,
|
||||
)
|
||||
TRIGGER_PROVIDER_REFRESH_INTERVAL: int = Field(
|
||||
description="Trigger provider refresh poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
TRIGGER_PROVIDER_REFRESH_BATCH_SIZE: int = Field(
|
||||
description="Max trigger subscriptions to process per tick",
|
||||
default=200,
|
||||
)
|
||||
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
|
||||
description="Proactive credential refresh threshold in seconds",
|
||||
default=60 * 60,
|
||||
)
|
||||
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
|
||||
description="Proactive subscription refresh threshold in seconds",
|
||||
default=60 * 60,
|
||||
)
|
||||
|
||||
|
||||
class PositionConfig(BaseSettings):
|
||||
@ -1078,7 +1190,7 @@ class AccountConfig(BaseSettings):
|
||||
|
||||
|
||||
class WorkflowLogConfig(BaseSettings):
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
|
||||
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
|
||||
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
|
||||
default=100, description="Batch size for workflow run log cleanup operations"
|
||||
@ -1097,12 +1209,21 @@ class SwaggerUIConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
|
||||
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
@ -1121,6 +1242,7 @@ class FeatureConfig(
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
|
||||
@ -18,6 +18,7 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
|
||||
from .storage.supabase_storage_config import SupabaseStorageConfig
|
||||
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
|
||||
from .vdb.analyticdb_config import AnalyticdbConfig
|
||||
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||
from .vdb.chroma_config import ChromaConfig
|
||||
@ -104,6 +105,12 @@ class KeywordStoreConfig(BaseSettings):
|
||||
|
||||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
# Database type selector
|
||||
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
|
||||
description="Database type to use. OceanBase is MySQL-compatible.",
|
||||
default="postgresql",
|
||||
)
|
||||
|
||||
DB_HOST: str = Field(
|
||||
description="Hostname or IP address of the database server.",
|
||||
default="localhost",
|
||||
@ -139,12 +146,12 @@ class DatabaseConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
|
||||
description="Database URI scheme for SQLAlchemy connection.",
|
||||
default="postgresql",
|
||||
)
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
|
||||
return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql"
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||
db_extras = (
|
||||
@ -197,21 +204,21 @@ class DatabaseConfig(BaseSettings):
|
||||
default=os.cpu_count() or 1,
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
||||
# Parse DB_EXTRAS for 'options'
|
||||
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
|
||||
options = db_extras_dict.get("options", "")
|
||||
# Always include timezone
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
# Merge user options and timezone
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
|
||||
connect_args = {"options": merged_options}
|
||||
connect_args = {}
|
||||
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
|
||||
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
connect_args = {"options": merged_options}
|
||||
|
||||
return {
|
||||
"pool_size": self.SQLALCHEMY_POOL_SIZE,
|
||||
@ -330,6 +337,7 @@ class MiddlewareConfig(
|
||||
ClickzettaConfig,
|
||||
HuaweiCloudConfig,
|
||||
MilvusConfig,
|
||||
AlibabaCloudMySQLConfig,
|
||||
MyScaleConfig,
|
||||
OpenSearchConfig,
|
||||
OracleConfig,
|
||||
|
||||
54
api/configs/middleware/vdb/alibabacloud_mysql_config.py
Normal file
54
api/configs/middleware/vdb/alibabacloud_mysql_config.py
Normal file
@ -0,0 +1,54 @@
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AlibabaCloudMySQLConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for AlibabaCloud MySQL vector database
|
||||
"""
|
||||
|
||||
ALIBABACLOUD_MYSQL_HOST: str = Field(
|
||||
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
|
||||
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
|
||||
default=3306,
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_USER: str = Field(
|
||||
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
|
||||
default="root",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
|
||||
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
|
||||
default="",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
|
||||
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
|
||||
description="Maximum number of connections in the connection pool",
|
||||
default=5,
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
|
||||
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
|
||||
default="utf8mb4",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
|
||||
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
|
||||
"(e.g., 'cosine', 'euclidean')",
|
||||
default="cosine",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
|
||||
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
|
||||
default=6,
|
||||
)
|
||||
@ -1,23 +1,24 @@
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AuthMethod(StrEnum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||
|
||||
|
||||
class OpenSearchConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for OpenSearch
|
||||
"""
|
||||
|
||||
class AuthMethod(Enum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||
|
||||
OPENSEARCH_HOST: str | None = Field(
|
||||
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
|
||||
default=None,
|
||||
|
||||
@ -22,7 +22,17 @@ class WeaviateConfig(BaseSettings):
|
||||
default=True,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
|
||||
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Number of objects to be processed in a single batch operation (default is 100)",
|
||||
default=100,
|
||||
)
|
||||
|
||||
WEAVIATE_TOKENIZATION: str | None = Field(
|
||||
description="Tokenization for Weaviate (default is word)",
|
||||
default="word",
|
||||
)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from configs import dify_config
|
||||
from libs.collection_utils import convert_to_lower_and_upper_set
|
||||
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||
UNKNOWN_VALUE = "[__UNKNOWN__]"
|
||||
@ -6,24 +7,39 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
DEFAULT_FILE_NUMBER_LIMITS = 3
|
||||
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
|
||||
|
||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "webm"]
|
||||
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
|
||||
|
||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "amr", "mpga"]
|
||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
|
||||
|
||||
|
||||
_doc_extensions: list[str]
|
||||
_doc_extensions: set[str]
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
_doc_extensions = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
|
||||
_doc_extensions.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
"mdx",
|
||||
"pdf",
|
||||
"html",
|
||||
"htm",
|
||||
"xlsx",
|
||||
"xls",
|
||||
"vtt",
|
||||
"properties",
|
||||
"doc",
|
||||
"docx",
|
||||
"csv",
|
||||
"eml",
|
||||
"msg",
|
||||
"pptx",
|
||||
"xml",
|
||||
"epub",
|
||||
}
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
_doc_extensions.append("ppt")
|
||||
_doc_extensions.add("ppt")
|
||||
else:
|
||||
_doc_extensions = [
|
||||
_doc_extensions = {
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
@ -37,5 +53,18 @@ else:
|
||||
"csv",
|
||||
"vtt",
|
||||
"properties",
|
||||
]
|
||||
DOCUMENT_EXTENSIONS = _doc_extensions + [ext.upper() for ext in _doc_extensions]
|
||||
}
|
||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||
|
||||
# console
|
||||
COOKIE_NAME_ACCESS_TOKEN = "access_token"
|
||||
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
|
||||
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
|
||||
|
||||
# webapp
|
||||
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
|
||||
COOKIE_NAME_PASSPORT = "passport"
|
||||
|
||||
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
|
||||
HEADER_NAME_APP_CODE = "X-App-Code"
|
||||
HEADER_NAME_PASSPORT = "X-App-Passport"
|
||||
|
||||
@ -31,3 +31,9 @@ def supported_language(lang):
|
||||
|
||||
error = f"{lang} is not a valid language."
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def get_valid_language(lang: str | None) -> str:
|
||||
if lang and lang in languages:
|
||||
return lang
|
||||
return languages[0]
|
||||
|
||||
7343
api/constants/pipeline_templates.json
Normal file
7343
api/constants/pipeline_templates.json
Normal file
File diff suppressed because one or more lines are too long
@ -9,6 +9,7 @@ if TYPE_CHECKING:
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
|
||||
|
||||
"""
|
||||
@ -41,3 +42,11 @@ datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginPro
|
||||
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("datasource_plugin_providers_lock")
|
||||
)
|
||||
|
||||
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers")
|
||||
)
|
||||
|
||||
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers_lock")
|
||||
)
|
||||
|
||||
@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
|
||||
code = 415
|
||||
|
||||
|
||||
class BlockedFileExtensionError(BaseHTTPException):
|
||||
error_code = "file_extension_blocked"
|
||||
description = "The file extension is blocked for security reasons."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
description = "Only one file is allowed."
|
||||
|
||||
@ -24,7 +24,7 @@ except ImportError:
|
||||
)
|
||||
else:
|
||||
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
|
||||
magic = None # type: ignore
|
||||
magic = None # type: ignore[assignment]
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -66,6 +66,7 @@ from .app import (
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
@ -126,6 +127,7 @@ from .workspace import (
|
||||
models,
|
||||
plugin,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
)
|
||||
|
||||
@ -196,6 +198,7 @@ __all__ = [
|
||||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trigger_providers",
|
||||
"version",
|
||||
"website",
|
||||
"workflow",
|
||||
@ -203,5 +206,6 @@ __all__ = [
|
||||
"workflow_draft_variable",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
"workflow_trigger",
|
||||
"workspace",
|
||||
]
|
||||
|
||||
@ -12,9 +12,10 @@ P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None:
|
||||
auth_token = extract_access_token(request)
|
||||
if not auth_token:
|
||||
raise Unauthorized("Authorization header is missing.")
|
||||
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
if auth_token != dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
@ -47,10 +38,10 @@ def admin_required(view: Callable[P, R]):
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps")
|
||||
class InsertExploreAppListApi(Resource):
|
||||
@api.doc("insert_explore_app")
|
||||
@api.doc(description="Insert or update an app in the explore list")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("insert_explore_app")
|
||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"InsertExploreAppRequest",
|
||||
{
|
||||
"app_id": fields.String(required=True, description="Application ID"),
|
||||
@ -64,21 +55,23 @@ class InsertExploreAppListApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "App updated successfully")
|
||||
@api.response(201, "App inserted successfully")
|
||||
@api.response(404, "App not found")
|
||||
@console_ns.response(200, "App updated successfully")
|
||||
@console_ns.response(201, "App inserted successfully")
|
||||
@console_ns.response(404, "App not found")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("desc", type=str, location="json")
|
||||
parser.add_argument("copyright", type=str, location="json")
|
||||
parser.add_argument("privacy_policy", type=str, location="json")
|
||||
parser.add_argument("custom_disclaimer", type=str, location="json")
|
||||
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("desc", type=str, location="json")
|
||||
.add_argument("copyright", type=str, location="json")
|
||||
.add_argument("privacy_policy", type=str, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, location="json")
|
||||
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
||||
@ -138,10 +131,10 @@ class InsertExploreAppListApi(Resource):
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
|
||||
class InsertExploreAppApi(Resource):
|
||||
@api.doc("delete_explore_app")
|
||||
@api.doc(description="Remove an app from the explore list")
|
||||
@api.doc(params={"app_id": "Application ID to remove"})
|
||||
@api.response(204, "App removed successfully")
|
||||
@console_ns.doc("delete_explore_app")
|
||||
@console_ns.doc(description="Remove an app from the explore list")
|
||||
@console_ns.doc(params={"app_id": "Application ID to remove"})
|
||||
@console_ns.response(204, "App removed successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import flask_restx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
@ -8,12 +7,12 @@ from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
from . import api, console_ns
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
api_key_fields = {
|
||||
"id": fields.String,
|
||||
@ -25,6 +24,12 @@ api_key_fields = {
|
||||
|
||||
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
|
||||
|
||||
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
|
||||
|
||||
api_key_list_model = console_ns.model(
|
||||
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
||||
)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
if resource_model == App:
|
||||
@ -53,11 +58,13 @@ class BaseApiKeyListResource(Resource):
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
@marshal_with(api_key_list)
|
||||
@marshal_with(api_key_list_model)
|
||||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
@ -65,14 +72,13 @@ class BaseApiKeyListResource(Resource):
|
||||
).all()
|
||||
return {"items": keys}
|
||||
|
||||
@marshal_with(api_key_fields)
|
||||
@marshal_with(api_key_item_model)
|
||||
@edit_permission_required
|
||||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||
@ -89,7 +95,7 @@ class BaseApiKeyListResource(Resource):
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
api_token.tenant_id = current_tenant_id
|
||||
api_token.token = key
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
@ -104,13 +110,11 @@ class BaseApiKeyResource(Resource):
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
def delete(self, resource_id, api_key_id):
|
||||
def delete(self, resource_id: str, api_key_id: str):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
@ -135,28 +139,23 @@ class BaseApiKeyResource(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@api.doc("get_app_api_keys")
|
||||
@api.doc(description="Get all API keys for an app")
|
||||
@api.doc(params={"resource_id": "App ID"})
|
||||
@api.response(200, "Success", api_key_list)
|
||||
def get(self, resource_id):
|
||||
@console_ns.doc("get_app_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
|
||||
@api.doc("create_app_api_key")
|
||||
@api.doc(description="Create a new API key for an app")
|
||||
@api.doc(params={"resource_id": "App ID"})
|
||||
@api.response(201, "API key created successfully", api_key_fields)
|
||||
@api.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id):
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "app"
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
@ -165,19 +164,14 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
class AppApiKeyResource(BaseApiKeyResource):
|
||||
@api.doc("delete_app_api_key")
|
||||
@api.doc(description="Delete an API key for an app")
|
||||
@api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@api.response(204, "API key deleted successfully")
|
||||
@console_ns.doc("delete_app_api_key")
|
||||
@console_ns.doc(description="Delete an API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "app"
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
@ -185,28 +179,23 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:resource_id>/api-keys")
|
||||
class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@api.doc("get_dataset_api_keys")
|
||||
@api.doc(description="Get all API keys for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID"})
|
||||
@api.response(200, "Success", api_key_list)
|
||||
def get(self, resource_id):
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
|
||||
@api.doc("create_dataset_api_key")
|
||||
@api.doc(description="Create a new API key for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID"})
|
||||
@api.response(201, "API key created successfully", api_key_fields)
|
||||
@api.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id):
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
@ -215,19 +204,14 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
@api.doc("delete_dataset_api_key")
|
||||
@api.doc(description="Delete an API key for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@api.response(204, "API key deleted successfully")
|
||||
@console_ns.doc("delete_dataset_api_key")
|
||||
@console_ns.doc(description="Delete an API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
|
||||
@ -1,35 +1,39 @@
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
||||
class AdvancedPromptTemplateQuery(BaseModel):
|
||||
app_mode: str = Field(..., description="Application mode")
|
||||
model_mode: str = Field(..., description="Model mode")
|
||||
has_context: str = Field(default="true", description="Whether has context")
|
||||
model_name: str = Field(..., description="Model name")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AdvancedPromptTemplateQuery.__name__,
|
||||
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/app/prompt-templates")
|
||||
class AdvancedPromptTemplateList(Resource):
|
||||
@api.doc("get_advanced_prompt_templates")
|
||||
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.doc("get_advanced_prompt_templates")
|
||||
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
|
||||
@console_ns.response(
|
||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||
)
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("app_mode", type=str, required=True, location="args")
|
||||
parser.add_argument("model_mode", type=str, required=True, location="args")
|
||||
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
|
||||
parser.add_argument("model_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.helper import uuid_value
|
||||
@ -8,29 +8,29 @@ from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
|
||||
class AgentLogApi(Resource):
|
||||
@api.doc("get_agent_logs")
|
||||
@api.doc(description="Get agent execution logs for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
|
||||
@console_ns.doc("get_agent_logs")
|
||||
@console_ns.doc(description="Get agent execution logs for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
|
||||
)
|
||||
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
|
||||
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||
|
||||
@ -1,33 +1,34 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
build_annotation_model,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@api.doc("annotation_reply_action")
|
||||
@api.doc(description="Enable or disable annotation reply for an app")
|
||||
@api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("annotation_reply_action")
|
||||
@console_ns.doc(description="Enable or disable annotation reply for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AnnotationReplyActionRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
|
||||
@ -36,21 +37,21 @@ class AnnotationReplyActionApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Action completed successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Action completed successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||
@ -61,18 +62,16 @@ class AnnotationReplyActionApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-setting")
|
||||
class AppAnnotationSettingDetailApi(Resource):
|
||||
@api.doc("get_annotation_setting")
|
||||
@api.doc(description="Get annotation settings for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Annotation settings retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("get_annotation_setting")
|
||||
@console_ns.doc(description="Get annotation settings for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Annotation settings retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
||||
return result, 200
|
||||
@ -80,11 +79,11 @@ class AppAnnotationSettingDetailApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
|
||||
class AppAnnotationSettingUpdateApi(Resource):
|
||||
@api.doc("update_annotation_setting")
|
||||
@api.doc(description="Update annotation settings for an app")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_annotation_setting")
|
||||
@console_ns.doc(description="Update annotation settings for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AnnotationSettingUpdateRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold"),
|
||||
@ -93,20 +92,17 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Settings updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Settings updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id, annotation_setting_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||
@ -115,19 +111,17 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>")
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@api.doc("get_annotation_reply_action_status")
|
||||
@api.doc(description="Get status of annotation reply action job")
|
||||
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
|
||||
@api.response(200, "Job status retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("get_annotation_reply_action_status")
|
||||
@console_ns.doc(description="Get status of annotation reply action job")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
|
||||
@console_ns.response(200, "Job status retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id, job_id, action):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
@ -145,24 +139,22 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
class AnnotationApi(Resource):
|
||||
@api.doc("list_annotations")
|
||||
@api.doc(description="Get annotations for an app with pagination")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("list_annotations")
|
||||
@console_ns.doc(description="Get annotations for an app with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
|
||||
)
|
||||
@api.response(200, "Annotations retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(200, "Annotations retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
@ -178,45 +170,48 @@ class AnnotationApi(Resource):
|
||||
}
|
||||
return response, 200
|
||||
|
||||
@api.doc("create_annotation")
|
||||
@api.doc(description="Create a new annotation for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_annotation")
|
||||
@console_ns.doc(description="Create a new annotation for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateAnnotationRequest",
|
||||
{
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"message_id": fields.String(description="Message ID (optional)"),
|
||||
"question": fields.String(description="Question text (required when message_id not provided)"),
|
||||
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
||||
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "Annotation created successfully", annotation_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
.add_argument("question", required=False, type=str, location="json")
|
||||
.add_argument("answer", required=False, type=str, location="json")
|
||||
.add_argument("content", required=False, type=str, location="json")
|
||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
||||
# Use request.args.getlist to get annotation_ids array directly
|
||||
@ -241,46 +236,51 @@ class AnnotationApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
|
||||
class AnnotationExportApi(Resource):
|
||||
@api.doc("export_annotations")
|
||||
@api.doc(description="Export all annotations for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields)))
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("export_annotations")
|
||||
@console_ns.doc(description="Export all annotations for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotations exported successfully",
|
||||
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response = {"data": marshal(annotation_list, annotation_fields)}
|
||||
return response, 200
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@api.doc("update_delete_annotation")
|
||||
@api.doc(description="Update or delete an annotation")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@api.response(200, "Annotation updated successfully", annotation_fields)
|
||||
@api.response(204, "Annotation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("update_delete_annotation")
|
||||
@console_ns.doc(description="Update or delete an annotation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
return annotation
|
||||
@ -288,10 +288,8 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||
@ -300,20 +298,18 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
|
||||
class AnnotationBatchImportApi(Resource):
|
||||
@api.doc("batch_import_annotations")
|
||||
@api.doc(description="Batch import annotations from CSV file")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Batch import started successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "No file uploaded or too many files")
|
||||
@console_ns.doc("batch_import_annotations")
|
||||
@console_ns.doc(description="Batch import annotations from CSV file")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Batch import started successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "No file uploaded or too many files")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
@ -332,19 +328,17 @@ class AnnotationBatchImportApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
|
||||
class AnnotationBatchImportStatusApi(Resource):
|
||||
@api.doc("get_batch_import_status")
|
||||
@api.doc(description="Get status of batch import job")
|
||||
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
|
||||
@api.response(200, "Job status retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("get_batch_import_status")
|
||||
@console_ns.doc(description="Get status of batch import job")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
|
||||
@console_ns.response(200, "Job status retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id, job_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
@ -361,25 +355,32 @@ class AnnotationBatchImportStatusApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
|
||||
class AnnotationHitHistoryListApi(Resource):
|
||||
@api.doc("list_annotation_hit_histories")
|
||||
@api.doc(description="Get hit histories for an annotation")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
@console_ns.doc("list_annotation_hit_histories")
|
||||
@console_ns.doc(description="Get hit histories for an annotation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
)
|
||||
@api.response(
|
||||
200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields))
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Hit histories retrieved successfully",
|
||||
console_ns.model(
|
||||
"AnnotationHitHistoryList",
|
||||
{
|
||||
"data": fields.List(
|
||||
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
app_id = str(app_id)
|
||||
|
||||
@ -1,101 +1,295 @@
|
||||
import uuid
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, abort
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from fields.app_fields import (
|
||||
deleted_tool_fields,
|
||||
model_config_fields,
|
||||
model_config_partial_fields,
|
||||
site_fields,
|
||||
tag_fields,
|
||||
)
|
||||
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
|
||||
from libs.helper import AppIconUrlField, TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import App, Workflow
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if description and len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(value, list):
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
else:
|
||||
raise TypeError("Unsupported tag_ids type.")
|
||||
|
||||
if not items:
|
||||
return None
|
||||
|
||||
try:
|
||||
return [str(uuid.UUID(item)) for item in items]
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class AppExportQuery(BaseModel):
|
||||
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
|
||||
|
||||
|
||||
class AppNamePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="Name to check")
|
||||
|
||||
|
||||
class AppIconPayload(BaseModel):
|
||||
icon: str | None = Field(default=None, description="Icon data")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
|
||||
class AppSiteStatusPayload(BaseModel):
|
||||
enable_site: bool = Field(..., description="Enable or disable site")
|
||||
|
||||
|
||||
class AppApiStatusPayload(BaseModel):
|
||||
enable_api: bool = Field(..., description="Enable or disable API")
|
||||
|
||||
|
||||
class AppTracePayload(BaseModel):
|
||||
enabled: bool = Field(..., description="Enable or disable tracing")
|
||||
tracing_provider: str = Field(..., description="Tracing provider")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AppListQuery)
|
||||
reg(CreateAppPayload)
|
||||
reg(UpdateAppPayload)
|
||||
reg(CopyAppPayload)
|
||||
reg(AppExportQuery)
|
||||
reg(AppNamePayload)
|
||||
reg(AppIconPayload)
|
||||
reg(AppSiteStatusPayload)
|
||||
reg(AppApiStatusPayload)
|
||||
reg(AppTracePayload)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base models first
|
||||
tag_model = console_ns.model("Tag", tag_fields)
|
||||
|
||||
workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
|
||||
|
||||
model_config_model = console_ns.model("ModelConfig", model_config_fields)
|
||||
|
||||
model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
|
||||
|
||||
deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
|
||||
|
||||
site_model = console_ns.model("Site", site_fields)
|
||||
|
||||
app_partial_model = console_ns.model(
|
||||
"AppPartial",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"max_active_requests": fields.Raw(),
|
||||
"description": fields.String(attribute="desc_or_prompt"),
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
"access_mode": fields.String,
|
||||
"create_user_name": fields.String,
|
||||
"author_name": fields.String,
|
||||
"has_draft_trigger": fields.Boolean,
|
||||
},
|
||||
)
|
||||
|
||||
app_detail_model = console_ns.model(
|
||||
"AppDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"tracing": fields.Raw,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
},
|
||||
)
|
||||
|
||||
app_detail_with_site_model = console_ns.model(
|
||||
"AppDetailWithSite",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"api_base_url": fields.String,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"max_active_requests": fields.Integer,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
"site": fields.Nested(site_model),
|
||||
},
|
||||
)
|
||||
|
||||
app_pagination_model = console_ns.model(
|
||||
"AppPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(app_partial_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@api.doc("list_apps")
|
||||
@api.doc(description="Get list of applications with pagination and filtering")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
|
||||
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
|
||||
.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
|
||||
default="all",
|
||||
help="App mode filter",
|
||||
)
|
||||
.add_argument("name", type=str, location="args", help="Filter by app name")
|
||||
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
|
||||
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
|
||||
)
|
||||
@api.response(200, "Success", app_pagination_fields)
|
||||
@console_ns.doc("list_apps")
|
||||
@console_ns.doc(description="Get list of applications with pagination and filtering")
|
||||
@console_ns.expect(console_ns.models[AppListQuery.__name__])
|
||||
@console_ns.response(200, "Success", app_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
def uuid_list(value):
|
||||
try:
|
||||
return [str(uuid.UUID(v)) for v in value.split(",")]
|
||||
except ValueError:
|
||||
abort(400, message="Invalid UUID format in tag_ids.")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
choices=[
|
||||
"completion",
|
||||
"chat",
|
||||
"advanced-chat",
|
||||
"workflow",
|
||||
"agent-chat",
|
||||
"channel",
|
||||
"all",
|
||||
],
|
||||
default="all",
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument("name", type=str, location="args", required=False)
|
||||
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_dict = args.model_dump()
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
|
||||
@ -109,71 +303,72 @@ class AppListApi(Resource):
|
||||
if str(app.id) in res:
|
||||
app.access_mode = res[str(app.id)].access_mode
|
||||
|
||||
return marshal(app_pagination, app_pagination_fields), 200
|
||||
workflow_capable_app_ids = [
|
||||
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
|
||||
]
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
select(Workflow).where(
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
Workflow.app_id.in_(workflow_capable_app_ids),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
trigger_node_types = {
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
|
||||
@api.doc("create_app")
|
||||
@api.doc(description="Create a new application")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "App created successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
return marshal(app_pagination, app_pagination_model), 200
|
||||
|
||||
@console_ns.doc("create_app")
|
||||
@console_ns.doc(description="Create a new application")
|
||||
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
|
||||
@console_ns.response(201, "App created successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_fields)
|
||||
@marshal_with(app_detail_model)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
"""Create app"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if "mode" not in args or args["mode"] is None:
|
||||
raise BadRequest("mode is required")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
if current_user.current_tenant_id is None:
|
||||
raise ValueError("current_user.current_tenant_id cannot be None")
|
||||
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
|
||||
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
|
||||
|
||||
return app, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>")
|
||||
class AppApi(Resource):
|
||||
@api.doc("get_app_detail")
|
||||
@api.doc(description="Get application details")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Success", app_detail_fields_with_site)
|
||||
@console_ns.doc("get_app_detail")
|
||||
@console_ns.doc(description="Get application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Success", app_detail_with_site_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
app_service = AppService()
|
||||
@ -186,79 +381,50 @@ class AppApi(Resource):
|
||||
|
||||
return app_model
|
||||
|
||||
@api.doc("update_app")
|
||||
@api.doc(description="Update application details")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
||||
"max_active_requests": fields.Integer(description="Maximum active requests"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "App updated successfully", app_detail_fields_with_site)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.doc("update_app")
|
||||
@console_ns.doc(description="Update application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def put(self, app_model):
|
||||
"""Update app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
parser.add_argument("max_active_requests", type=int, location="json")
|
||||
args = parser.parse_args()
|
||||
args = UpdateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
# Construct ArgsDict from parsed arguments
|
||||
from services.app_service import AppService as AppServiceType
|
||||
|
||||
args_dict: AppServiceType.ArgsDict = {
|
||||
"name": args["name"],
|
||||
"description": args.get("description", ""),
|
||||
"icon_type": args.get("icon_type", ""),
|
||||
"icon": args.get("icon", ""),
|
||||
"icon_background": args.get("icon_background", ""),
|
||||
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
|
||||
"max_active_requests": args.get("max_active_requests", 0),
|
||||
args_dict: AppService.ArgsDict = {
|
||||
"name": args.name,
|
||||
"description": args.description or "",
|
||||
"icon_type": args.icon_type or "",
|
||||
"icon": args.icon or "",
|
||||
"icon_background": args.icon_background or "",
|
||||
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
|
||||
"max_active_requests": args.max_active_requests or 0,
|
||||
}
|
||||
app_model = app_service.update_app(app_model, args_dict)
|
||||
|
||||
return app_model
|
||||
|
||||
@api.doc("delete_app")
|
||||
@api.doc(description="Delete application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(204, "App deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("delete_app")
|
||||
@console_ns.doc(description="Delete application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(204, "App deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model):
|
||||
"""Delete app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
|
||||
@ -267,55 +433,37 @@ class AppApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/copy")
|
||||
class AppCopyApi(Resource):
|
||||
@api.doc("copy_app")
|
||||
@api.doc(description="Create a copy of an existing application")
|
||||
@api.doc(params={"app_id": "Application ID to copy"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CopyAppRequest",
|
||||
{
|
||||
"name": fields.String(description="Name for the copied app"),
|
||||
"description": fields.String(description="Description for the copied app"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "App copied successfully", app_detail_fields_with_site)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("copy_app")
|
||||
@console_ns.doc(description="Create a copy of an existing application")
|
||||
@console_ns.doc(params={"app_id": "Application ID to copy"})
|
||||
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
|
||||
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
def post(self, app_model):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=_validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
account=current_user,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=yaml_content,
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
icon_type=args.get("icon_type"),
|
||||
icon=args.get("icon"),
|
||||
icon_background=args.get("icon_background"),
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@ -327,178 +475,131 @@ class AppCopyApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/export")
|
||||
class AppExportApi(Resource):
|
||||
@api.doc("export_app")
|
||||
@api.doc(description="Export application configuration as DSL")
|
||||
@api.doc(params={"app_id": "Application ID to export"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
|
||||
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.doc("export_app")
|
||||
@console_ns.doc(description="Export application configuration as DSL")
|
||||
@console_ns.doc(params={"app_id": "Application ID to export"})
|
||||
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"App exported successfully",
|
||||
api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
|
||||
console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
|
||||
)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
parser.add_argument("workflow_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
return {
|
||||
"data": AppDslService.export_dsl(
|
||||
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
|
||||
app_model=app_model,
|
||||
include_secret=args.include_secret,
|
||||
workflow_id=args.workflow_id,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@api.doc("check_app_name")
|
||||
@api.doc(description="Check if app name is available")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
|
||||
@api.response(200, "Name availability checked")
|
||||
@console_ns.doc("check_app_name")
|
||||
@console_ns.doc(description="Check if app name is available")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
|
||||
@console_ns.response(200, "Name availability checked")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppNamePayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_name(app_model, args["name"])
|
||||
app_model = app_service.update_app_name(app_model, args.name)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/icon")
|
||||
class AppIconApi(Resource):
|
||||
@api.doc("update_app_icon")
|
||||
@api.doc(description="Update application icon")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppIconRequest",
|
||||
{
|
||||
"icon": fields.String(required=True, description="Icon data"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Icon updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("update_app_icon")
|
||||
@console_ns.doc(description="Update application icon")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppIconPayload.__name__])
|
||||
@console_ns.response(200, "Icon updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppIconPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
|
||||
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site-enable")
|
||||
class AppSiteStatus(Resource):
|
||||
@api.doc("update_app_site_status")
|
||||
@api.doc(description="Enable or disable app site")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
|
||||
)
|
||||
)
|
||||
@api.response(200, "Site status updated successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("update_app_site_status")
|
||||
@console_ns.doc(description="Enable or disable app site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
|
||||
@console_ns.response(200, "Site status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enable_site", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppSiteStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
|
||||
app_model = app_service.update_app_site_status(app_model, args.enable_site)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/api-enable")
|
||||
class AppApiStatus(Resource):
|
||||
@api.doc("update_app_api_status")
|
||||
@api.doc(description="Enable or disable app API")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
|
||||
)
|
||||
)
|
||||
@api.response(200, "API status updated successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("update_app_api_status")
|
||||
@console_ns.doc(description="Enable or disable app API")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
|
||||
@console_ns.response(200, "API status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields)
|
||||
@marshal_with(app_detail_model)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enable_api", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppApiStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
|
||||
app_model = app_service.update_app_api_status(app_model, args.enable_api)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace")
|
||||
class AppTraceApi(Resource):
|
||||
@api.doc("get_app_trace")
|
||||
@api.doc(description="Get app tracing configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Trace configuration retrieved successfully")
|
||||
@console_ns.doc("get_app_trace")
|
||||
@console_ns.doc(description="Get app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Trace configuration retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -508,36 +609,24 @@ class AppTraceApi(Resource):
|
||||
|
||||
return app_trace_config
|
||||
|
||||
@api.doc("update_app_trace")
|
||||
@api.doc(description="Update app tracing configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppTraceRequest",
|
||||
{
|
||||
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trace configuration updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("update_app_trace")
|
||||
@console_ns.doc(description="Update app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
|
||||
@console_ns.response(200, "Trace configuration updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
# add app trace
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enabled", type=bool, required=True, location="json")
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=app_id,
|
||||
enabled=args["enabled"],
|
||||
tracing_provider=args["tracing_provider"],
|
||||
enabled=args.enabled,
|
||||
tracing_provider=args.tracing_provider,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -1,20 +1,20 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from fields.app_fields import (
|
||||
app_import_check_dependencies_fields,
|
||||
app_import_fields,
|
||||
leaked_dependency_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -22,36 +22,52 @@ from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
|
||||
|
||||
app_import_model = console_ns.model("AppImport", app_import_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
|
||||
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
|
||||
app_import_check_dependencies_model = console_ns.model(
|
||||
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
||||
)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("app_id", type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@console_ns.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_fields)
|
||||
@marshal_with(app_import_model)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("mode", type=str, required=True, location="json")
|
||||
parser.add_argument("yaml_content", type=str, location="json")
|
||||
parser.add_argument("yaml_url", type=str, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("app_id", type=str, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = cast(Account, current_user)
|
||||
account = current_user
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
@ -70,9 +86,9 @@ class AppImportApi(Resource):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
if status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
elif status == ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
@ -82,22 +98,22 @@ class AppImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_fields)
|
||||
@marshal_with(app_import_model)
|
||||
@edit_permission_required
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = cast(Account, current_user)
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
@ -108,11 +124,9 @@ class AppImportCheckDependenciesApi(Resource):
|
||||
@login_required
|
||||
@get_app_model
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_check_dependencies_fields)
|
||||
@marshal_with(app_import_check_dependencies_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
@ -5,7 +5,7 @@ from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
@ -36,16 +36,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||
class ChatMessageAudioApi(Resource):
|
||||
@api.doc("chat_message_audio_transcript")
|
||||
@api.doc(description="Transcript audio to text for chat messages")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.response(
|
||||
@console_ns.doc("chat_message_audio_transcript")
|
||||
@console_ns.doc(description="Transcript audio to text for chat messages")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Audio transcription successful",
|
||||
api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
|
||||
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
|
||||
)
|
||||
@api.response(400, "Bad request - No audio uploaded or unsupported type")
|
||||
@api.response(413, "Audio file too large")
|
||||
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
|
||||
@console_ns.response(413, "Audio file too large")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -89,11 +89,11 @@ class ChatMessageAudioApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/text-to-audio")
|
||||
class ChatMessageTextApi(Resource):
|
||||
@api.doc("chat_message_text_to_speech")
|
||||
@api.doc(description="Convert text to speech for chat messages")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("chat_message_text_to_speech")
|
||||
@console_ns.doc(description="Convert text to speech for chat messages")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TextToSpeechRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
@ -103,19 +103,21 @@ class ChatMessageTextApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Text to speech conversion successful")
|
||||
@api.response(400, "Bad request - Invalid parameters")
|
||||
@console_ns.response(200, "Text to speech conversion successful")
|
||||
@console_ns.response(400, "Bad request - Invalid parameters")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model: App):
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, location="json")
|
||||
.add_argument("text", type=str, location="json")
|
||||
.add_argument("voice", type=str, location="json")
|
||||
.add_argument("streaming", type=bool, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
@ -154,20 +156,23 @@ class ChatMessageTextApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices")
|
||||
class TextModesApi(Resource):
|
||||
@api.doc("get_text_to_speech_voices")
|
||||
@api.doc(description="Get available TTS voices for a specific language")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code"))
|
||||
@api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")))
|
||||
@api.response(400, "Invalid language parameter")
|
||||
@console_ns.doc("get_text_to_speech_voices")
|
||||
@console_ns.doc(description="Get available TTS voices for a specific language")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
|
||||
)
|
||||
@console_ns.response(
|
||||
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
|
||||
)
|
||||
@console_ns.response(400, "Invalid language parameter")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
try:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("language", type=str, required=True, location="args")
|
||||
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
@ -15,9 +17,8 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
@ -32,48 +33,66 @@ from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config")
|
||||
files: list[Any] | None = Field(default=None, description="Uploaded files")
|
||||
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
|
||||
retriever_from: str = Field(default="dev", description="Retriever source")
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseMessagePayload):
|
||||
query: str = Field(default="", description="Query text")
|
||||
|
||||
|
||||
class ChatMessagePayload(BaseMessagePayload):
|
||||
query: str = Field(..., description="User query")
|
||||
conversation_id: str | None = Field(default=None, description="Conversation ID")
|
||||
parent_message_id: str | None = Field(default=None, description="Parent message ID")
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionMessagePayload.__name__,
|
||||
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-messages")
|
||||
class CompletionMessageApi(Resource):
|
||||
@api.doc("create_completion_message")
|
||||
@api.doc(description="Generate completion message for debugging")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CompletionMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(description="Query text", default=""),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Completion generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(404, "App not found")
|
||||
@console_ns.doc("create_completion_message")
|
||||
@console_ns.doc(description="Generate completion message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
|
||||
@console_ns.response(200, "Completion generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
args = parser.parse_args()
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
@ -108,10 +127,10 @@ class CompletionMessageApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
|
||||
class CompletionMessageStopApi(Resource):
|
||||
@api.doc("stop_completion_message")
|
||||
@api.doc(description="Stop a running completion message generation")
|
||||
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@api.response(200, "Task stopped successfully")
|
||||
@console_ns.doc("stop_completion_message")
|
||||
@console_ns.doc(description="Stop a running completion message generation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -119,57 +138,36 @@ class CompletionMessageStopApi(Resource):
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
class ChatMessageApi(Resource):
|
||||
@api.doc("create_chat_message")
|
||||
@api.doc(description="Generate chat message for debugging")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ChatMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(required=True, description="User query"),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"conversation_id": fields.String(description="Conversation ID"),
|
||||
"parent_message_id": fields.String(description="Parent message ID"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Chat message generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(404, "App or conversation not found")
|
||||
@console_ns.doc("create_chat_message")
|
||||
@console_ns.doc(description="Generate chat message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
@console_ns.response(200, "Chat message generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App or conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
@ -210,10 +208,10 @@ class ChatMessageApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
|
||||
class ChatMessageStopApi(Resource):
|
||||
@api.doc("stop_chat_message")
|
||||
@api.doc(description="Stop a running chat message generation")
|
||||
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@api.response(200, "Task stopped successfully")
|
||||
@console_ns.doc("stop_chat_message")
|
||||
@console_ns.doc(description="Stop a running chat message generation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -221,6 +219,12 @@ class ChatMessageStopApi(Resource):
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -1,116 +1,377 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
import pytz # pip install pytz
|
||||
import sqlalchemy as sa
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
conversation_detail_fields,
|
||||
conversation_message_detail_fields,
|
||||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models import Account, Conversation, EndUser, Message, MessageAnnotation
|
||||
from fields.conversation_fields import MessageTextField
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseConversationQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
|
||||
default="all", description="Annotation status filter"
|
||||
)
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def blank_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
class CompletionConversationQuery(BaseConversationQuery):
|
||||
pass
|
||||
|
||||
|
||||
class ChatConversationQuery(BaseConversationQuery):
|
||||
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort field and direction"
|
||||
)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionConversationQuery.__name__,
|
||||
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatConversationQuery.__name__,
|
||||
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model(
|
||||
"SimpleAccount",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
feedback_stat_model = console_ns.model(
|
||||
"FeedbackStat",
|
||||
{
|
||||
"like": fields.Integer,
|
||||
"dislike": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
status_count_model = console_ns.model(
|
||||
"StatusCount",
|
||||
{
|
||||
"success": fields.Integer,
|
||||
"failed": fields.Integer,
|
||||
"partial_success": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
message_file_model = console_ns.model(
|
||||
"MessageFile",
|
||||
{
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
},
|
||||
)
|
||||
|
||||
agent_thought_model = console_ns.model(
|
||||
"AgentThought",
|
||||
{
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
},
|
||||
)
|
||||
|
||||
simple_model_config_model = console_ns.model(
|
||||
"SimpleModelConfig",
|
||||
{
|
||||
"model": fields.Raw(attribute="model_dict"),
|
||||
"pre_prompt": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
model_config_model = console_ns.model(
|
||||
"ModelConfig",
|
||||
{
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"model": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"pre_prompt": fields.String,
|
||||
"agent_mode": fields.Raw,
|
||||
},
|
||||
)
|
||||
|
||||
# Models that depend on simple_account_model
|
||||
feedback_model = console_ns.model(
|
||||
"Feedback",
|
||||
{
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
},
|
||||
)
|
||||
|
||||
annotation_model = console_ns.model(
|
||||
"Annotation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
annotation_hit_history_model = console_ns.model(
|
||||
"AnnotationHitHistory",
|
||||
{
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
# Simple message detail model
|
||||
simple_message_detail_model = console_ns.model(
|
||||
"SimpleMessageDetail",
|
||||
{
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": MessageTextField,
|
||||
"answer": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Message detail model that depends on multiple models
|
||||
message_detail_model = console_ns.model(
|
||||
"MessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_model)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Conversation models
|
||||
conversation_fields_model = console_ns.model(
|
||||
"Conversation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String(),
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"model_config": fields.Nested(simple_model_config_model),
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"message": fields.Nested(simple_message_detail_model, attribute="first_message"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_pagination_model = console_ns.model(
|
||||
"ConversationPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_fields_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_message_detail_model = console_ns.model(
|
||||
"ConversationMessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"model_config": fields.Nested(model_config_model),
|
||||
"message": fields.Nested(message_detail_model, attribute="first_message"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_with_summary_model = console_ns.model(
|
||||
"ConversationWithSummary",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"name": fields.String,
|
||||
"summary": fields.String(attribute="summary_or_query"),
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"model_config": fields.Nested(simple_model_config_model),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"status_count": fields.Nested(status_count_model),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_with_summary_pagination_model = console_ns.model(
|
||||
"ConversationWithSummaryPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_detail_model = console_ns.model(
|
||||
"ConversationDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"introduction": fields.String,
|
||||
"model_config": fields.Nested(model_config_model),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-conversations")
|
||||
class CompletionConversationApi(Resource):
|
||||
@api.doc("list_completion_conversations")
|
||||
@api.doc(description="Get completion conversations with pagination and filtering")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
help="Annotation status filter",
|
||||
)
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
)
|
||||
@api.response(200, "Success", conversation_pagination_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("list_completion_conversations")
|
||||
@console_ns.doc(description="Get completion conversations with pagination and filtering")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", conversation_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_pagination_fields)
|
||||
@marshal_with(conversation_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
)
|
||||
|
||||
if args["keyword"]:
|
||||
if args.keyword:
|
||||
query = query.join(Message, Message.conversation_id == Conversation.id).where(
|
||||
or_(
|
||||
Message.query.ilike(f"%{args['keyword']}%"),
|
||||
Message.answer.ilike(f"%{args['keyword']}%"),
|
||||
Message.query.ilike(f"%{args.keyword}%"),
|
||||
Message.answer.ilike(f"%{args.keyword}%"),
|
||||
)
|
||||
)
|
||||
|
||||
account = current_user
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
assert account.timezone is not None
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
|
||||
# FIXME, the type ignore in this file
|
||||
if args["annotation_status"] == "annotated":
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
@ -119,49 +380,46 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
|
||||
class CompletionConversationDetailApi(Resource):
|
||||
@api.doc("get_completion_conversation")
|
||||
@api.doc(description="Get completion conversation details with messages")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(200, "Success", conversation_message_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@console_ns.doc("get_completion_conversation")
|
||||
@console_ns.doc(description="Get completion conversation details with messages")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(200, "Success", conversation_message_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_message_detail_fields)
|
||||
@marshal_with(conversation_message_detail_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
|
||||
@api.doc("delete_completion_conversation")
|
||||
@api.doc(description="Delete a completion conversation")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(204, "Conversation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@console_ns.doc(description="Delete a completion conversation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@ -171,63 +429,21 @@ class CompletionConversationDetailApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
|
||||
class ChatConversationApi(Resource):
|
||||
@api.doc("list_chat_conversations")
|
||||
@api.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
help="Annotation status filter",
|
||||
)
|
||||
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
default="-updated_at",
|
||||
help="Sort field and direction",
|
||||
)
|
||||
)
|
||||
@api.response(200, "Success", conversation_with_summary_pagination_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("list_chat_conversations")
|
||||
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_with_summary_pagination_fields)
|
||||
@marshal_with(conversation_with_summary_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
||||
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
subquery = (
|
||||
db.session.query(
|
||||
@ -239,8 +455,8 @@ class ChatConversationApi(Resource):
|
||||
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args["keyword"]:
|
||||
keyword_filter = f"%{args['keyword']}%"
|
||||
if args.keyword:
|
||||
keyword_filter = f"%{args.keyword}%"
|
||||
query = (
|
||||
query.join(
|
||||
Message,
|
||||
@ -260,58 +476,51 @@ class ChatConversationApi(Resource):
|
||||
)
|
||||
|
||||
account = current_user
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
assert account.timezone is not None
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
match args["sort_by"]:
|
||||
if start_datetime_utc:
|
||||
match args.sort_by:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
match args["sort_by"]:
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
match args.sort_by:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||
|
||||
if args["annotation_status"] == "annotated":
|
||||
if args.annotation_status == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
|
||||
if args["message_count_gte"] and args["message_count_gte"] >= 1:
|
||||
if args.message_count_gte and args.message_count_gte >= 1:
|
||||
query = (
|
||||
query.options(joinedload(Conversation.messages)) # type: ignore
|
||||
.join(Message, Message.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(Message.id) >= args["message_count_gte"])
|
||||
.having(func.count(Message.id) >= args.message_count_gte)
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
|
||||
match args["sort_by"]:
|
||||
match args.sort_by:
|
||||
case "created_at":
|
||||
query = query.order_by(Conversation.created_at.asc())
|
||||
case "-created_at":
|
||||
@ -323,49 +532,46 @@ class ChatConversationApi(Resource):
|
||||
case _:
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
|
||||
class ChatConversationDetailApi(Resource):
|
||||
@api.doc("get_chat_conversation")
|
||||
@api.doc(description="Get chat conversation details")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(200, "Success", conversation_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@console_ns.doc("get_chat_conversation")
|
||||
@console_ns.doc(description="Get chat conversation details")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(200, "Success", conversation_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_detail_fields)
|
||||
@marshal_with(conversation_detail_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
|
||||
@api.doc("delete_chat_conversation")
|
||||
@api.doc(description="Delete a chat conversation")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(204, "Conversation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@console_ns.doc(description="Delete a chat conversation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@ -374,6 +580,7 @@ class ChatConversationDetailApi(Resource):
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
|
||||
@ -1,47 +1,68 @@
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_variable_fields import paginated_conversation_variable_fields
|
||||
from fields.conversation_variable_fields import (
|
||||
conversation_variable_fields,
|
||||
paginated_conversation_variable_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ConversationVariablesQuery.__name__,
|
||||
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
|
||||
paginated_conversation_variable_fields_copy["data"] = fields.List(
|
||||
fields.Nested(conversation_variable_model), attribute="data"
|
||||
)
|
||||
paginated_conversation_variable_model = console_ns.model(
|
||||
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/conversation-variables")
|
||||
class ConversationVariablesApi(Resource):
|
||||
@api.doc("get_conversation_variables")
|
||||
@api.doc(description="Get conversation variables for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
|
||||
)
|
||||
)
|
||||
@api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields)
|
||||
@console_ns.doc("get_conversation_variables")
|
||||
@console_ns.doc(description="Get conversation variables for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_fields)
|
||||
@marshal_with(paginated_conversation_variable_model)
|
||||
def get(self, app_model):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
stmt = (
|
||||
select(ConversationVariable)
|
||||
.where(ConversationVariable.app_id == app_model.id)
|
||||
.order_by(ConversationVariable.created_at)
|
||||
)
|
||||
if args["conversation_id"]:
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
|
||||
else:
|
||||
raise ValueError("conversation_id is required")
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
|
||||
|
||||
# NOTE: This is a temporary solution to avoid performance issues.
|
||||
page = 1
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@ -12,50 +13,80 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class RuleGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Rule generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
no_variable: bool = Field(default=False, description="Whether to exclude variables")
|
||||
|
||||
|
||||
class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
code_language: str = Field(default="javascript", description="Programming language for code generation")
|
||||
|
||||
|
||||
class RuleStructuredOutputPayload(BaseModel):
|
||||
instruction: str = Field(..., description="Structured output generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
|
||||
|
||||
class InstructionGeneratePayload(BaseModel):
|
||||
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||
node_id: str = Field(default="", description="Node ID for workflow context")
|
||||
current: str = Field(default="", description="Current instruction text")
|
||||
language: str = Field(default="javascript", description="Programming language (javascript/python)")
|
||||
instruction: str = Field(..., description="Instruction for generation")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
ideal_output: str = Field(default="", description="Expected ideal output")
|
||||
|
||||
|
||||
class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(RuleGeneratePayload)
|
||||
reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
class RuleGenerateApi(Resource):
|
||||
@api.doc("generate_rule_config")
|
||||
@api.doc(description="Generate rule configuration using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"RuleGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Rule generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Rule configuration generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@console_ns.doc("generate_rule_config")
|
||||
@console_ns.doc(description="Generate rule configuration using LLM")
|
||||
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Rule configuration generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=account.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=args["no_variable"],
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=args.no_variable,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -71,42 +102,25 @@ class RuleGenerateApi(Resource):
|
||||
|
||||
@console_ns.route("/rule-code-generate")
|
||||
class RuleCodeGenerateApi(Resource):
|
||||
@api.doc("generate_rule_code")
|
||||
@api.doc(description="Generate code rules using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"RuleCodeGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Code generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
"code_language": fields.String(
|
||||
default="javascript", description="Programming language for code generation"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Code rules generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@console_ns.doc("generate_rule_code")
|
||||
@console_ns.doc(description="Generate code rules using LLM")
|
||||
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Code rules generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
args = parser.parse_args()
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=account.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["code_language"],
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.code_language,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -122,35 +136,24 @@ class RuleCodeGenerateApi(Resource):
|
||||
|
||||
@console_ns.route("/rule-structured-output-generate")
|
||||
class RuleStructuredOutputGenerateApi(Resource):
|
||||
@api.doc("generate_structured_output")
|
||||
@api.doc(description="Generate structured output rules using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"StructuredOutputGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Structured output generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Structured output generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@console_ns.doc("generate_structured_output")
|
||||
@console_ns.doc(description="Generate structured output rules using LLM")
|
||||
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
|
||||
@console_ns.response(200, "Structured output generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=account.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -166,101 +169,79 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||
|
||||
@console_ns.route("/instruction-generate")
|
||||
class InstructionGenerateApi(Resource):
|
||||
@api.doc("generate_instruction")
|
||||
@api.doc(description="Generate instruction for workflow nodes or general use")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InstructionGenerateRequest",
|
||||
{
|
||||
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
|
||||
"node_id": fields.String(description="Node ID for workflow context"),
|
||||
"current": fields.String(description="Current instruction text"),
|
||||
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
|
||||
"instruction": fields.String(required=True, description="Instruction for generation"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Instruction generated successfully")
|
||||
@api.response(400, "Invalid request parameters or flow/workflow not found")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@console_ns.doc("generate_instruction")
|
||||
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
|
||||
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Instruction generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("flow_id", type=str, required=True, default="", location="json")
|
||||
parser.add_argument("node_id", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("current", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("language", type=str, required=False, default="javascript", location="json")
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
code_template = (
|
||||
Python3CodeProvider.get_default_code()
|
||||
if args["language"] == "python"
|
||||
else (JavascriptCodeProvider.get_default_code())
|
||||
if args["language"] == "javascript"
|
||||
else ""
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
if not app:
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
return {"error": f"workflow {args['flow_id']} not found"}, 400
|
||||
return {"error": f"workflow {args.flow_id} not found"}, 400
|
||||
nodes: Sequence = workflow.graph_dict["nodes"]
|
||||
node = [node for node in nodes if node["id"] == args["node_id"]]
|
||||
node = [node for node in nodes if node["id"] == args.node_id]
|
||||
if len(node) == 0:
|
||||
return {"error": f"node {args['node_id']} not found"}, 400
|
||||
return {"error": f"node {args.node_id} not found"}, 400
|
||||
node_type = node[0]["data"]["type"]
|
||||
match node_type:
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_user.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_user.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["language"],
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
)
|
||||
case _:
|
||||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||
if args.node_id == "" and args.current != "": # For legacy app without a workflow
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args.flow_id,
|
||||
current=args.current,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
)
|
||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||
if args.node_id != "" and args.current != "": # For workflow node
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
node_id=args["node_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args.flow_id,
|
||||
node_id=args.node_id,
|
||||
current=args.current,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
workflow_service=WorkflowService(),
|
||||
)
|
||||
return {"error": "incompatible parameters"}, 400
|
||||
@ -276,27 +257,17 @@ class InstructionGenerateApi(Resource):
|
||||
|
||||
@console_ns.route("/instruction-generate/template")
|
||||
class InstructionGenerationTemplateApi(Resource):
|
||||
@api.doc("get_instruction_template")
|
||||
@api.doc(description="Get instruction generation template")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InstructionTemplateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Template instruction"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Template retrieved successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.doc("get_instruction_template")
|
||||
@console_ns.doc(description="Get instruction generation template")
|
||||
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
|
||||
@console_ns.response(200, "Template retrieved successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
match args["type"]:
|
||||
args = InstructionTemplatePayload.model_validate(console_ns.payload)
|
||||
match args.type:
|
||||
case "prompt":
|
||||
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
|
||||
|
||||
@ -306,4 +277,4 @@ class InstructionGenerationTemplateApi(Resource):
|
||||
|
||||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||
case _:
|
||||
raise ValueError(f"Invalid type: {args['type']}")
|
||||
raise ValueError(f"Invalid type: {args.type}")
|
||||
|
||||
@ -1,18 +1,20 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||
|
||||
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
@ -21,24 +23,24 @@ class AppMCPServerStatus(StrEnum):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||
class AppMCPServerController(Resource):
|
||||
@api.doc("get_app_mcp_server")
|
||||
@api.doc(description="Get MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
|
||||
@setup_required
|
||||
@console_ns.doc("get_app_mcp_server")
|
||||
@console_ns.doc(description="Get MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
@marshal_with(app_server_model)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@api.doc("create_app_mcp_server")
|
||||
@api.doc(description="Create MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MCPServerCreateRequest",
|
||||
{
|
||||
"description": fields.String(description="Server description"),
|
||||
@ -46,19 +48,21 @@ class AppMCPServerController(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "MCP server configuration created successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
@login_required
|
||||
@setup_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
description = args.get("description")
|
||||
@ -71,18 +75,18 @@ class AppMCPServerController(Resource):
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
server_code=AppMCPServer.generate_server_code(16),
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
@api.doc("update_app_mcp_server")
|
||||
@api.doc(description="Update MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app_mcp_server")
|
||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"MCPServerUpdateRequest",
|
||||
{
|
||||
"id": fields.String(required=True, description="Server ID"),
|
||||
@ -92,22 +96,23 @@ class AppMCPServerController(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Server not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("id", type=str, required=True, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
.add_argument("status", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
@ -132,23 +137,23 @@ class AppMCPServerController(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
|
||||
class AppMCPServerRefreshController(Resource):
|
||||
@api.doc("refresh_app_mcp_server")
|
||||
@api.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@api.doc(params={"server_id": "Server ID"})
|
||||
@api.response(200, "MCP server refreshed successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Server not found")
|
||||
@console_ns.doc("refresh_app_mcp_server")
|
||||
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@console_ns.doc(params={"server_id": "Server ID"})
|
||||
@console_ns.response(200, "MCP server refreshed successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
def get(self, server_id):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id)
|
||||
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
|
||||
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not server:
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@ -16,74 +18,233 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
||||
from libs.helper import uuid_value
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ChatMessagesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID")
|
||||
first_id: str | None = Field(default=None, description="First message ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||
|
||||
@field_validator("first_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
@field_validator("conversation_id", "first_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class FeedbackExportQuery(BaseModel):
|
||||
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
|
||||
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
|
||||
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
|
||||
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
|
||||
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
|
||||
|
||||
@field_validator("has_comment", mode="before")
|
||||
@classmethod
|
||||
def parse_bool(cls, value: bool | str | None) -> bool | None:
|
||||
if isinstance(value, bool) or value is None:
|
||||
return value
|
||||
lowered = value.lower()
|
||||
if lowered in {"true", "1", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"false", "0", "no", "off"}:
|
||||
return False
|
||||
raise ValueError("has_comment must be a boolean value")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(ChatMessagesQuery)
|
||||
reg(MessageFeedbackPayload)
|
||||
reg(FeedbackExportQuery)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model(
|
||||
"SimpleAccount",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
message_file_model = console_ns.model(
|
||||
"MessageFile",
|
||||
{
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
},
|
||||
)
|
||||
|
||||
agent_thought_model = console_ns.model(
|
||||
"AgentThought",
|
||||
{
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
},
|
||||
)
|
||||
|
||||
# Models that depend on simple_account_model
|
||||
feedback_model = console_ns.model(
|
||||
"Feedback",
|
||||
{
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
},
|
||||
)
|
||||
|
||||
annotation_model = console_ns.model(
|
||||
"Annotation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
annotation_hit_history_model = console_ns.model(
|
||||
"AnnotationHitHistory",
|
||||
{
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
# Message detail model that depends on multiple models
|
||||
message_detail_model = console_ns.model(
|
||||
"MessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_model)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Message infinite scroll pagination model
|
||||
message_infinite_scroll_pagination_model = console_ns.model(
|
||||
"MessageInfiniteScrollPagination",
|
||||
{
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_detail_model)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
class ChatMessageListApi(Resource):
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_detail_fields)),
|
||||
}
|
||||
|
||||
@api.doc("list_chat_messages")
|
||||
@api.doc(description="Get chat messages for a conversation with pagination")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
|
||||
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
|
||||
)
|
||||
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@console_ns.doc("list_chat_messages")
|
||||
@console_ns.doc(description="Get chat messages for a conversation with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
|
||||
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args["first_id"]:
|
||||
if args.first_id:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -98,7 +259,7 @@ class ChatMessageListApi(Resource):
|
||||
Message.id != first_message.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
@ -106,12 +267,12 @@ class ChatMessageListApi(Resource):
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args["limit"])
|
||||
.limit(args.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args["limit"]:
|
||||
if len(history_messages) == args.limit:
|
||||
current_page_first_message = history_messages[-1]
|
||||
# Check if there are more messages before the current page
|
||||
has_more = db.session.scalar(
|
||||
@ -129,40 +290,28 @@ class ChatMessageListApi(Resource):
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
class MessageFeedbackApi(Resource):
|
||||
@api.doc("create_message_feedback")
|
||||
@api.doc(description="Create or update message feedback (like/dislike)")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MessageFeedbackRequest",
|
||||
{
|
||||
"message_id": fields.String(required=True, description="Message ID"),
|
||||
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Feedback updated successfully")
|
||||
@api.response(404, "Message not found")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@console_ns.doc("create_message_feedback")
|
||||
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback updated successfully")
|
||||
@console_ns.response(404, "Message not found")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model):
|
||||
if current_user is None:
|
||||
raise Forbidden()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
args = parser.parse_args()
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
|
||||
message_id = str(args["message_id"])
|
||||
message_id = str(args.message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
@ -171,18 +320,21 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
feedback = message.admin_feedback
|
||||
|
||||
if not args["rating"] and feedback:
|
||||
if not args.rating and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args["rating"] and feedback:
|
||||
feedback.rating = args["rating"]
|
||||
elif not args["rating"] and not feedback:
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
elif not args.rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
rating_value = args.rating
|
||||
if rating_value is None:
|
||||
raise ValueError("rating is required to create feedback")
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=args["rating"],
|
||||
rating=rating_value,
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
@ -193,56 +345,15 @@ class MessageFeedbackApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
class MessageAnnotationApi(Resource):
|
||||
@api.doc("create_message_annotation")
|
||||
@api.doc(description="Create message annotation")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MessageAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Annotation created successfully", annotation_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@get_app_model
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_model):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
||||
|
||||
return annotation
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
|
||||
class MessageAnnotationCountApi(Resource):
|
||||
@api.doc("get_annotation_count")
|
||||
@api.doc(description="Get count of message annotations for the app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(
|
||||
@console_ns.doc("get_annotation_count")
|
||||
@console_ns.doc(description="Get count of message annotations for the app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotation count retrieved successfully",
|
||||
api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
|
||||
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@ -256,20 +367,23 @@ class MessageAnnotationCountApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
|
||||
class MessageSuggestedQuestionApi(Resource):
|
||||
@api.doc("get_message_suggested_questions")
|
||||
@api.doc(description="Get suggested questions for a message")
|
||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@api.response(
|
||||
@console_ns.doc("get_message_suggested_questions")
|
||||
@console_ns.doc(description="Get suggested questions for a message")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Suggested questions retrieved successfully",
|
||||
api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}),
|
||||
console_ns.model(
|
||||
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
|
||||
),
|
||||
)
|
||||
@api.response(404, "Message or conversation not found")
|
||||
@console_ns.response(404, "Message or conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
@ -297,19 +411,59 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
return {"data": questions}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
|
||||
class MessageApi(Resource):
|
||||
@api.doc("get_message")
|
||||
@api.doc(description="Get message details by ID")
|
||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@api.response(200, "Message retrieved successfully", message_detail_fields)
|
||||
@api.response(404, "Message not found")
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
|
||||
class MessageFeedbackExportApi(Resource):
|
||||
@console_ns.doc("export_feedbacks")
|
||||
@console_ns.doc(description="Export user feedback data for Google Sheets")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
|
||||
@console_ns.response(200, "Feedback data exported successfully")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# Import the service function
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
try:
|
||||
export_data = FeedbackService.export_feedbacks(
|
||||
app_id=app_model.id,
|
||||
from_source=args.from_source,
|
||||
rating=args.rating,
|
||||
has_comment=args.has_comment,
|
||||
start_date=args.start_date,
|
||||
end_date=args.end_date,
|
||||
format_type=args.format,
|
||||
)
|
||||
|
||||
return export_data
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Parameter validation error in feedback export")
|
||||
return {"error": f"Parameter validation error: {str(e)}"}, 400
|
||||
except Exception as e:
|
||||
logger.exception("Error exporting feedback data")
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
|
||||
class MessageApi(Resource):
|
||||
@console_ns.doc("get_message")
|
||||
@console_ns.doc(description="Get message details by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@console_ns.response(200, "Message retrieved successfully", message_detail_model)
|
||||
@console_ns.response(404, "Message not found")
|
||||
@get_app_model
|
||||
@marshal_with(message_detail_fields)
|
||||
def get(self, app_model, message_id):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(message_detail_model)
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id = str(message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
@ -2,31 +2,29 @@ import json
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/model-config")
|
||||
class ModelConfigResource(Resource):
|
||||
@api.doc("update_app_model_config")
|
||||
@api.doc(description="Update application model configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app_model_config")
|
||||
@console_ns.doc(description="Update application model configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ModelConfigRequest",
|
||||
{
|
||||
"provider": fields.String(description="Model provider"),
|
||||
@ -44,25 +42,20 @@ class ModelConfigResource(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Model configuration updated successfully")
|
||||
@api.response(400, "Invalid configuration")
|
||||
@api.response(404, "App not found")
|
||||
@console_ns.response(200, "Model configuration updated successfully")
|
||||
@console_ns.response(400, "Invalid configuration")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
"""Modify app model config"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
config=cast(dict, request.json),
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
@ -90,16 +83,16 @@ class ModelConfigResource(Resource):
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
agent_tool_entity = AgentToolEntity.model_validate(tool)
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
@ -124,7 +117,7 @@ class ModelConfigResource(Resource):
|
||||
# encrypt agent tool parameters if it's secret-input
|
||||
agent_mode = new_app_model_config.agent_mode_dict
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
agent_tool_entity = AgentToolEntity.model_validate(tool)
|
||||
|
||||
# get tool
|
||||
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
|
||||
@ -133,7 +126,7 @@ class ModelConfigResource(Resource):
|
||||
else:
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
@ -141,7 +134,7 @@ class ModelConfigResource(Resource):
|
||||
continue
|
||||
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
@ -172,6 +165,8 @@ class ModelConfigResource(Resource):
|
||||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app_model.updated_by = current_user.id
|
||||
app_model.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
@ -14,24 +14,23 @@ class TraceAppConfigApi(Resource):
|
||||
Manage trace app configurations
|
||||
"""
|
||||
|
||||
@api.doc("get_trace_app_config")
|
||||
@api.doc(description="Get tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
@console_ns.doc("get_trace_app_config")
|
||||
@console_ns.doc(description="Get tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
||||
)
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -42,11 +41,11 @@ class TraceAppConfigApi(Resource):
|
||||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@api.doc("create_trace_app_config")
|
||||
@api.doc(description="Create a new tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_trace_app_config")
|
||||
@console_ns.doc(description="Create a new tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TraceConfigCreateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
@ -54,18 +53,20 @@ class TraceAppConfigApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
||||
)
|
||||
@api.response(400, "Invalid request parameters or configuration already exists")
|
||||
@console_ns.response(400, "Invalid request parameters or configuration already exists")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id):
|
||||
"""Create a new trace app configuration"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -80,11 +81,11 @@ class TraceAppConfigApi(Resource):
|
||||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@api.doc("update_trace_app_config")
|
||||
@api.doc(description="Update an existing tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_trace_app_config")
|
||||
@console_ns.doc(description="Update an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"TraceConfigUpdateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
@ -92,16 +93,18 @@ class TraceAppConfigApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||
@api.response(400, "Invalid request parameters or configuration not found")
|
||||
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, app_id):
|
||||
"""Update an existing trace app configuration"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -114,23 +117,22 @@ class TraceAppConfigApi(Resource):
|
||||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@api.doc("delete_trace_app_config")
|
||||
@api.doc(description="Delete an existing tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
@console_ns.doc("delete_trace_app_config")
|
||||
@console_ns.doc(description="Delete an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@api.response(204, "Tracing configuration deleted successfully")
|
||||
@api.response(400, "Invalid request parameters or configuration not found")
|
||||
@console_ns.response(204, "Tracing configuration deleted successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
"""Delete an existing trace app configuration"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
@ -1,48 +1,61 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import Account, Site
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_site_model = console_ns.model("AppSite", app_site_fields)
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("title", type=str, required=False, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=False, location="json")
|
||||
parser.add_argument("icon", type=str, required=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, location="json")
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||
parser.add_argument("customize_domain", type=str, required=False, location="json")
|
||||
parser.add_argument("copyright", type=str, required=False, location="json")
|
||||
parser.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||
parser.add_argument(
|
||||
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("title", type=str, required=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, location="json")
|
||||
.add_argument("icon", type=str, required=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||
.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||
.add_argument("customize_domain", type=str, required=False, location="json")
|
||||
.add_argument("copyright", type=str, required=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||
.add_argument(
|
||||
"customize_token_strategy",
|
||||
type=str,
|
||||
choices=["must", "allow", "not_allow"],
|
||||
required=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
)
|
||||
parser.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||
class AppSite(Resource):
|
||||
@api.doc("update_app_site")
|
||||
@api.doc(description="Update application site configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_app_site")
|
||||
@console_ns.doc(description="Update application site configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"AppSiteRequest",
|
||||
{
|
||||
"title": fields.String(description="Site title"),
|
||||
@ -66,21 +79,18 @@ class AppSite(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Site configuration updated successfully", app_site_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "App not found")
|
||||
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_fields)
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
args = parse_app_site_args()
|
||||
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if not site:
|
||||
raise NotFound
|
||||
@ -107,8 +117,6 @@ class AppSite(Resource):
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@ -118,30 +126,26 @@ class AppSite(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
|
||||
class AppSiteAccessTokenReset(Resource):
|
||||
@api.doc("reset_app_site_access_token")
|
||||
@api.doc(description="Reset access token for application site")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Access token reset successfully", app_site_fields)
|
||||
@api.response(403, "Insufficient permissions (admin/owner required)")
|
||||
@api.response(404, "App or site not found")
|
||||
@console_ns.doc("reset_app_site_access_token")
|
||||
@console_ns.doc(description="Reset access token for application site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Access token reset successfully", app_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions (admin/owner required)")
|
||||
@console_ns.response(404, "App or site not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_fields)
|
||||
@marshal_with(app_site_model)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
site.code = Site.generate_code(16)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
@ -1,33 +1,48 @@
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from flask import abort, jsonify, request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models import AppMode, Message
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def empty_string_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
StatisticTimeRangeQuery.__name__,
|
||||
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
class DailyMessageStatistic(Resource):
|
||||
@api.doc("get_daily_message_statistics")
|
||||
@api.doc(description="Get daily message statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.doc("get_daily_message_statistics")
|
||||
@console_ns.doc(description="Get daily message statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily message statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily message count data")),
|
||||
@ -37,43 +52,32 @@ class DailyMessageStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(*) AS message_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -91,15 +95,11 @@ WHERE
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@api.doc("get_daily_conversation_statistics")
|
||||
@api.doc(description="Get daily conversation statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.doc("get_daily_conversation_statistics")
|
||||
@console_ns.doc(description="Get daily conversation statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily conversation count data")),
|
||||
@ -109,63 +109,53 @@ class DailyConversationStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(DISTINCT conversation_id) AS conversation_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
stmt = (
|
||||
sa.select(
|
||||
sa.func.date(
|
||||
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
|
||||
).label("date"),
|
||||
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
|
||||
)
|
||||
.select_from(Message)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
stmt = stmt.where(Message.created_at >= start_datetime_utc)
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
stmt = stmt.where(Message.created_at < end_datetime_utc)
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
stmt = stmt.group_by("date").order_by("date")
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(stmt, {"tz": account.timezone})
|
||||
for row in rs:
|
||||
response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users")
|
||||
class DailyTerminalsStatistic(Resource):
|
||||
@api.doc("get_daily_terminals_statistics")
|
||||
@api.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.doc("get_daily_terminals_statistics")
|
||||
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily terminal count data")),
|
||||
@ -175,43 +165,32 @@ class DailyTerminalsStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -229,15 +208,11 @@ WHERE
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/token-costs")
|
||||
class DailyTokenCostStatistic(Resource):
|
||||
@api.doc("get_daily_token_cost_statistics")
|
||||
@api.doc(description="Get daily token cost statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.doc("get_daily_token_cost_statistics")
|
||||
@console_ns.doc(description="Get daily token cost statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily token cost data")),
|
||||
@ -247,15 +222,13 @@ class DailyTokenCostStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
|
||||
SUM(total_price) AS total_price
|
||||
FROM
|
||||
@ -263,28 +236,19 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -304,15 +268,11 @@ WHERE
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions")
|
||||
class AverageSessionInteractionStatistic(Resource):
|
||||
@api.doc("get_average_session_interaction_statistics")
|
||||
@api.doc(description="Get average session interaction statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.doc("get_average_session_interaction_statistics")
|
||||
@console_ns.doc(description="Get average session interaction statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Average session interaction data")),
|
||||
@ -322,15 +282,13 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
AVG(subquery.message_count) AS interactions
|
||||
FROM
|
||||
(
|
||||
@ -345,28 +303,19 @@ FROM
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND c.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND c.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -395,15 +344,11 @@ ORDER BY
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
|
||||
class UserSatisfactionRateStatistic(Resource):
|
||||
@api.doc("get_user_satisfaction_rate_statistics")
|
||||
@api.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.doc("get_user_satisfaction_rate_statistics")
|
||||
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="User satisfaction rate data")),
|
||||
@ -413,15 +358,13 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(m.id) AS message_count,
|
||||
COUNT(mf.id) AS feedback_count
|
||||
FROM
|
||||
@ -432,28 +375,19 @@ LEFT JOIN
|
||||
WHERE
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND m.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND m.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -476,15 +410,11 @@ WHERE
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time")
|
||||
class AverageResponseTimeStatistic(Resource):
|
||||
@api.doc("get_average_response_time_statistics")
|
||||
@api.doc(description="Get average response time statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.doc("get_average_response_time_statistics")
|
||||
@console_ns.doc(description="Get average response time statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Average response time data")),
|
||||
@ -494,43 +424,32 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
AVG(provider_response_latency) AS latency
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -548,15 +467,11 @@ WHERE
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second")
|
||||
class TokensPerSecondStatistic(Resource):
|
||||
@api.doc("get_tokens_per_second_statistics")
|
||||
@api.doc(description="Get tokens per second statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.doc("get_tokens_per_second_statistics")
|
||||
@console_ns.doc(description="Get tokens per second statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Tokens per second data")),
|
||||
@ -566,15 +481,12 @@ class TokensPerSecondStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
CASE
|
||||
WHEN SUM(provider_response_latency) = 0 THEN 0
|
||||
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
||||
@ -584,28 +496,19 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,82 +1,85 @@
|
||||
from datetime import datetime
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowAppLogQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
|
||||
status: WorkflowExecutionStatus | None = Field(
|
||||
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
|
||||
)
|
||||
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
|
||||
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
|
||||
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
|
||||
created_by_account: str | None = Field(default=None, description="Filter by account")
|
||||
detail: bool = Field(default=False, description="Whether to return detailed logs")
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
|
||||
@field_validator("created_at__before", "created_at__after", mode="before")
|
||||
@classmethod
|
||||
def parse_datetime(cls, value: str | None) -> datetime | None:
|
||||
if value in (None, ""):
|
||||
return None
|
||||
return isoparse(value) # type: ignore
|
||||
|
||||
@field_validator("detail", mode="before")
|
||||
@classmethod
|
||||
def parse_bool(cls, value: bool | str | None) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
lowered = value.lower()
|
||||
if lowered in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
raise ValueError("Invalid boolean value for detail")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@api.doc("get_workflow_app_logs")
|
||||
@api.doc(description="Get workflow application execution logs")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(
|
||||
params={
|
||||
"keyword": "Search keyword for filtering logs",
|
||||
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
|
||||
"created_at__before": "Filter logs created before this timestamp",
|
||||
"created_at__after": "Filter logs created after this timestamp",
|
||||
"created_by_end_user_session_id": "Filter by end user session ID",
|
||||
"created_by_account": "Filter by account",
|
||||
"page": "Page number (1-99999)",
|
||||
"limit": "Number of items per page (1-100)",
|
||||
}
|
||||
)
|
||||
@api.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_fields)
|
||||
@console_ns.doc("get_workflow_app_logs")
|
||||
@console_ns.doc(description="Get workflow application execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_app_log_pagination_fields)
|
||||
@marshal_with(workflow_app_log_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = isoparse(args.created_at__after)
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
@ -90,6 +93,7 @@ class WorkflowAppLogApi(Resource):
|
||||
created_at_after=args.created_at__after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
detail=args.detail,
|
||||
created_by_end_user_session_id=args.created_by_end_user_session_id,
|
||||
created_by_account=args.created_by_account,
|
||||
)
|
||||
|
||||
@ -1,17 +1,18 @@
|
||||
import logging
|
||||
from typing import NoReturn
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import NoReturn, ParamSpec, TypeVar
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.file import helpers as file_helpers
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
@ -21,9 +22,8 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -58,16 +58,18 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
@ -139,8 +141,42 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_draft_variable_without_value_model = console_ns.model(
|
||||
"WorkflowDraftVariableWithoutValue", _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS
|
||||
)
|
||||
|
||||
def _api_prerequisite(f):
|
||||
workflow_draft_variable_model = console_ns.model("WorkflowDraftVariable", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
workflow_draft_env_variable_model = console_ns.model("WorkflowDraftEnvVariable", _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)
|
||||
|
||||
workflow_draft_env_variable_list_fields_copy = _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS.copy()
|
||||
workflow_draft_env_variable_list_fields_copy["items"] = fields.List(fields.Nested(workflow_draft_env_variable_model))
|
||||
workflow_draft_env_variable_list_model = console_ns.model(
|
||||
"WorkflowDraftEnvVariableList", workflow_draft_env_variable_list_fields_copy
|
||||
)
|
||||
|
||||
workflow_draft_variable_list_without_value_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS.copy()
|
||||
workflow_draft_variable_list_without_value_fields_copy["items"] = fields.List(
|
||||
fields.Nested(workflow_draft_variable_without_value_model), attribute=_get_items
|
||||
)
|
||||
workflow_draft_variable_list_without_value_model = console_ns.model(
|
||||
"WorkflowDraftVariableListWithoutValue", workflow_draft_variable_list_without_value_fields_copy
|
||||
)
|
||||
|
||||
workflow_draft_variable_list_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS.copy()
|
||||
workflow_draft_variable_list_fields_copy["items"] = fields.List(
|
||||
fields.Nested(workflow_draft_variable_model), attribute=_get_items
|
||||
)
|
||||
workflow_draft_variable_list_model = console_ns.model(
|
||||
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _api_prerequisite(f: Callable[P, R]):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -154,11 +190,10 @@ def _api_prerequisite(f):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@ -166,13 +201,16 @@ def _api_prerequisite(f):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@api.doc("get_workflow_variables")
|
||||
@api.doc(description="Get draft workflow variables")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
|
||||
@api.response(200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
@console_ns.expect(_create_pagination_parser())
|
||||
@console_ns.doc("get_workflow_variables")
|
||||
@console_ns.doc(description="Get draft workflow variables")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.response(
|
||||
200, "Workflow variables retrieved successfully", workflow_draft_variable_list_without_value_model
|
||||
)
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
@ -199,9 +237,9 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@api.doc("delete_workflow_variables")
|
||||
@api.doc(description="Delete all draft workflow variables")
|
||||
@api.response(204, "Workflow variables deleted successfully")
|
||||
@console_ns.doc("delete_workflow_variables")
|
||||
@console_ns.doc(description="Delete all draft workflow variables")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -232,12 +270,12 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class NodeVariableCollectionApi(Resource):
|
||||
@api.doc("get_node_variables")
|
||||
@api.doc(description="Get variables for a specific node")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@console_ns.doc("get_node_variables")
|
||||
@console_ns.doc(description="Get variables for a specific node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
@ -248,9 +286,9 @@ class NodeVariableCollectionApi(Resource):
|
||||
|
||||
return node_vars
|
||||
|
||||
@api.doc("delete_node_variables")
|
||||
@api.doc(description="Delete all variables for a specific node")
|
||||
@api.response(204, "Node variables deleted successfully")
|
||||
@console_ns.doc("delete_node_variables")
|
||||
@console_ns.doc(description="Delete all variables for a specific node")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
@ -265,13 +303,13 @@ class VariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@api.doc("get_variable")
|
||||
@api.doc(description="Get a specific workflow variable")
|
||||
@api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@api.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@api.response(404, "Variable not found")
|
||||
@console_ns.doc("get_variable")
|
||||
@console_ns.doc(description="Get a specific workflow variable")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
@ -283,10 +321,10 @@ class VariableApi(Resource):
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@api.doc("update_variable")
|
||||
@api.doc(description="Update a workflow variable")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_variable")
|
||||
@console_ns.doc(description="Update a workflow variable")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateVariableRequest",
|
||||
{
|
||||
"name": fields.String(description="Variable name"),
|
||||
@ -294,10 +332,10 @@ class VariableApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@api.response(404, "Variable not found")
|
||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
@ -320,10 +358,11 @@ class VariableApi(Resource):
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
@ -358,10 +397,10 @@ class VariableApi(Resource):
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@api.doc("delete_variable")
|
||||
@api.doc(description="Delete a workflow variable")
|
||||
@api.response(204, "Variable deleted successfully")
|
||||
@api.response(404, "Variable not found")
|
||||
@console_ns.doc("delete_variable")
|
||||
@console_ns.doc(description="Delete a workflow variable")
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -379,12 +418,12 @@ class VariableApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class VariableResetApi(Resource):
|
||||
@api.doc("reset_variable")
|
||||
@api.doc(description="Reset a workflow variable to its default value")
|
||||
@api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@api.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@api.response(204, "Variable reset (no content)")
|
||||
@api.response(404, "Variable not found")
|
||||
@console_ns.doc("reset_variable")
|
||||
@console_ns.doc(description="Reset a workflow variable to its default value")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -408,7 +447,7 @@ class VariableResetApi(Resource):
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
else:
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
@ -427,13 +466,13 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
class ConversationVariableCollectionApi(Resource):
|
||||
@api.doc("get_conversation_variables")
|
||||
@api.doc(description="Get conversation variables for workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@api.response(404, "Draft workflow not found")
|
||||
@console_ns.doc("get_conversation_variables")
|
||||
@console_ns.doc(description="Get conversation variables for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App):
|
||||
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||
# so their IDs can be returned to the caller.
|
||||
@ -449,23 +488,23 @@ class ConversationVariableCollectionApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@api.doc("get_system_variables")
|
||||
@api.doc(description="Get system variables for workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@console_ns.doc("get_system_variables")
|
||||
@console_ns.doc(description="Get system variables for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
class EnvironmentVariableCollectionApi(Resource):
|
||||
@api.doc("get_environment_variables")
|
||||
@api.doc(description="Get environment variables for workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Environment variables retrieved successfully")
|
||||
@api.response(404, "Draft workflow not found")
|
||||
@console_ns.doc("get_environment_variables")
|
||||
@console_ns.doc(description="Get environment variables for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
|
||||
@ -1,90 +1,341 @@
|
||||
from typing import cast
|
||||
from typing import Literal, cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_run_fields import (
|
||||
advanced_chat_workflow_run_for_list_fields,
|
||||
advanced_chat_workflow_run_pagination_fields,
|
||||
workflow_run_count_fields,
|
||||
workflow_run_detail_fields,
|
||||
workflow_run_for_list_fields,
|
||||
workflow_run_node_execution_fields,
|
||||
workflow_run_node_execution_list_fields,
|
||||
workflow_run_pagination_fields,
|
||||
)
|
||||
from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models import Account, App, AppMode, EndUser
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
# Workflow run status choices for filtering
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
|
||||
|
||||
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
|
||||
# Models that depend on simple_account_fields
|
||||
workflow_run_for_list_fields_copy = workflow_run_for_list_fields.copy()
|
||||
workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_for_list_model = console_ns.model("WorkflowRunForList", workflow_run_for_list_fields_copy)
|
||||
|
||||
advanced_chat_workflow_run_for_list_fields_copy = advanced_chat_workflow_run_for_list_fields.copy()
|
||||
advanced_chat_workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
advanced_chat_workflow_run_for_list_model = console_ns.model(
|
||||
"AdvancedChatWorkflowRunForList", advanced_chat_workflow_run_for_list_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_detail_fields_copy = workflow_run_detail_fields.copy()
|
||||
workflow_run_detail_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_detail_fields_copy["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
workflow_run_detail_model = console_ns.model("WorkflowRunDetail", workflow_run_detail_fields_copy)
|
||||
|
||||
workflow_run_node_execution_fields_copy = workflow_run_node_execution_fields.copy()
|
||||
workflow_run_node_execution_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_node_execution_fields_copy["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
workflow_run_node_execution_model = console_ns.model(
|
||||
"WorkflowRunNodeExecution", workflow_run_node_execution_fields_copy
|
||||
)
|
||||
|
||||
# Simple models without nested dependencies
|
||||
workflow_run_count_model = console_ns.model("WorkflowRunCount", workflow_run_count_fields)
|
||||
|
||||
# Pagination models that depend on list models
|
||||
advanced_chat_workflow_run_pagination_fields_copy = advanced_chat_workflow_run_pagination_fields.copy()
|
||||
advanced_chat_workflow_run_pagination_fields_copy["data"] = fields.List(
|
||||
fields.Nested(advanced_chat_workflow_run_for_list_model), attribute="data"
|
||||
)
|
||||
advanced_chat_workflow_run_pagination_model = console_ns.model(
|
||||
"AdvancedChatWorkflowRunPagination", advanced_chat_workflow_run_pagination_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_pagination_fields_copy = workflow_run_pagination_fields.copy()
|
||||
workflow_run_pagination_fields_copy["data"] = fields.List(fields.Nested(workflow_run_for_list_model), attribute="data")
|
||||
workflow_run_pagination_model = console_ns.model("WorkflowRunPagination", workflow_run_pagination_fields_copy)
|
||||
|
||||
workflow_run_node_execution_list_fields_copy = workflow_run_node_execution_list_fields.copy()
|
||||
workflow_run_node_execution_list_fields_copy["data"] = fields.List(fields.Nested(workflow_run_node_execution_model))
|
||||
workflow_run_node_execution_list_model = console_ns.model(
|
||||
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowRunListQuery(BaseModel):
|
||||
last_id: str | None = Field(default=None, description="Last run ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
)
|
||||
|
||||
@field_validator("last_id")
|
||||
@classmethod
|
||||
def validate_last_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class WorkflowRunCountQuery(BaseModel):
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
)
|
||||
|
||||
@field_validator("time_range")
|
||||
@classmethod
|
||||
def validate_time_range(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return time_duration(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowRunCountQuery.__name__,
|
||||
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@api.doc("get_advanced_chat_workflow_runs")
|
||||
@api.doc(description="Get advanced chat workflow run list")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs")
|
||||
@console_ns.doc(description="Get advanced chat workflow run list")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_fields)
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
|
||||
app_model=app_model, args=args, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
|
||||
class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs_count")
|
||||
@console_ns.doc(description="Get advanced chat workflow runs count statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
)
|
||||
}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat workflow runs count statistics
|
||||
"""
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status=args.get("status"),
|
||||
time_range=args.get("time_range"),
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs")
|
||||
class WorkflowRunListApi(Resource):
|
||||
@api.doc("get_workflow_runs")
|
||||
@api.doc(description="Get workflow run list")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
|
||||
@console_ns.doc("get_workflow_runs")
|
||||
@console_ns.doc(description="Get workflow run list")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_pagination_fields)
|
||||
@marshal_with(workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
|
||||
result = workflow_run_service.get_paginate_workflow_runs(
|
||||
app_model=app_model, args=args, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
|
||||
class WorkflowRunCountApi(Resource):
|
||||
@console_ns.doc("get_workflow_runs_count")
|
||||
@console_ns.doc(description="Get workflow runs count statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
)
|
||||
}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow runs count statistics
|
||||
"""
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status=args.get("status"),
|
||||
time_range=args.get("time_range"),
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
|
||||
class WorkflowRunDetailApi(Resource):
|
||||
@api.doc("get_workflow_run_detail")
|
||||
@api.doc(description="Get workflow run detail")
|
||||
@api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@api.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_fields)
|
||||
@api.response(404, "Workflow run not found")
|
||||
@console_ns.doc("get_workflow_run_detail")
|
||||
@console_ns.doc(description="Get workflow run detail")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_detail_fields)
|
||||
@marshal_with(workflow_run_detail_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run detail
|
||||
@ -99,16 +350,16 @@ class WorkflowRunDetailApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@api.doc("get_workflow_run_node_executions")
|
||||
@api.doc(description="Get workflow run node execution list")
|
||||
@api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@api.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_fields)
|
||||
@api.response(404, "Workflow run not found")
|
||||
@console_ns.doc("get_workflow_run_node_executions")
|
||||
@console_ns.doc(description="Get workflow run node execution list")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_list_fields)
|
||||
@marshal_with(workflow_run_node_execution_list_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
|
||||
@ -1,311 +1,194 @@
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from flask import abort, jsonify, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowStatisticQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def blank_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowStatisticQuery.__name__,
|
||||
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||
class WorkflowDailyRunsStatistic(Resource):
|
||||
@api.doc("get_workflow_daily_runs_statistic")
|
||||
@api.doc(description="Get workflow daily runs statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Daily runs statistics retrieved successfully")
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@console_ns.doc("get_workflow_daily_runs_statistic")
|
||||
@console_ns.doc(description="Get workflow daily runs statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily runs statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(id) AS runs
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "runs": i.runs})
|
||||
response_data = self._workflow_run_repo.get_daily_runs_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
||||
class WorkflowDailyTerminalsStatistic(Resource):
|
||||
@api.doc("get_workflow_daily_terminals_statistic")
|
||||
@api.doc(description="Get workflow daily terminals statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Daily terminals statistics retrieved successfully")
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@console_ns.doc("get_workflow_daily_terminals_statistic")
|
||||
@console_ns.doc(description="Get workflow daily terminals statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
|
||||
class WorkflowDailyTokenCostStatistic(Resource):
|
||||
@api.doc("get_workflow_daily_token_cost_statistic")
|
||||
@api.doc(description="Get workflow daily token cost statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Daily token cost statistics retrieved successfully")
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@console_ns.doc("get_workflow_daily_token_cost_statistic")
|
||||
@console_ns.doc(description="Get workflow daily token cost statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
SUM(workflow_runs.total_tokens) AS token_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{
|
||||
"date": str(i.date),
|
||||
"token_count": i.token_count,
|
||||
}
|
||||
)
|
||||
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
|
||||
class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
@api.doc("get_workflow_average_app_interaction_statistic")
|
||||
@api.doc(description="Get workflow average app interaction statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Average app interaction statistics retrieved successfully")
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@console_ns.doc("get_workflow_average_app_interaction_statistic")
|
||||
@console_ns.doc(description="Get workflow average app interaction statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
sql_query = """SELECT
|
||||
AVG(sub.interactions) AS interactions,
|
||||
sub.date
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
c.created_by,
|
||||
COUNT(c.id) AS interactions
|
||||
FROM
|
||||
workflow_runs c
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND c.triggered_from = :triggered_from
|
||||
{{start}}
|
||||
{{end}}
|
||||
GROUP BY
|
||||
date, c.created_by
|
||||
) sub
|
||||
GROUP BY
|
||||
sub.date"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{start}}", "")
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{end}}", "")
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||
)
|
||||
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
157
api/controllers/console/app/workflow_trigger.py
Normal file
157
api/controllers/console/app/workflow_trigger.py
Normal file
@ -0,0 +1,157 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from libs.login import current_user, login_required
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import Account, App, AppMode
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
|
||||
from .. import console_ns
|
||||
from ..app.wraps import get_app_model
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class Parser(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
class ParserEnable(BaseModel):
|
||||
trigger_id: str
|
||||
enable_trigger: bool
|
||||
|
||||
|
||||
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
class WebhookTriggerApi(Resource):
|
||||
"""Webhook Trigger API"""
|
||||
|
||||
@console_ns.expect(console_ns.models[Parser.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(webhook_trigger_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Get webhook trigger for a node"""
|
||||
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
node_id = args.node_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger)
|
||||
.where(
|
||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||
WorkflowWebhookTrigger.node_id == node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not webhook_trigger:
|
||||
raise NotFound("Webhook trigger not found for this node")
|
||||
|
||||
return webhook_trigger
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/triggers")
|
||||
class AppTriggersApi(Resource):
|
||||
"""App Triggers list API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(triggers_list_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Add computed icon field for each trigger
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
for trigger in triggers:
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return {"data": triggers}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(trigger_fields)
|
||||
def post(self, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
args = ParserEnable.model_validate(console_ns.payload)
|
||||
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
trigger_id = args.trigger_id
|
||||
with Session(db.engine) as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
AppTrigger.id == trigger_id,
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not trigger:
|
||||
raise NotFound("Trigger not found")
|
||||
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Add computed icon field
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
@ -4,28 +4,29 @@ from typing import ParamSpec, TypeVar, Union
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
P1 = ParamSpec("P1")
|
||||
R1 = TypeVar("R1")
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
assert isinstance(current_user, Account)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
|
||||
@ -2,35 +2,31 @@ from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
||||
from models.account import AccountStatus
|
||||
from models import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
active_check_parser = reqparse.RequestParser()
|
||||
active_check_parser.add_argument(
|
||||
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
|
||||
)
|
||||
active_check_parser.add_argument(
|
||||
"email", type=email, required=False, nullable=True, location="args", help="Email address"
|
||||
)
|
||||
active_check_parser.add_argument(
|
||||
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
|
||||
active_check_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
class ActivateCheckApi(Resource):
|
||||
@api.doc("check_activation_token")
|
||||
@api.doc(description="Check if activation token is valid")
|
||||
@api.expect(active_check_parser)
|
||||
@api.response(
|
||||
@console_ns.doc("check_activation_token")
|
||||
@console_ns.doc(description="Check if activation token is valid")
|
||||
@console_ns.expect(active_check_parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"ActivationCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether token is valid"),
|
||||
@ -60,26 +56,26 @@ class ActivateCheckApi(Resource):
|
||||
return {"is_valid": False}
|
||||
|
||||
|
||||
active_parser = reqparse.RequestParser()
|
||||
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
active_parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
active_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
)
|
||||
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/activate")
|
||||
class ActivateApi(Resource):
|
||||
@api.doc("activate_account")
|
||||
@api.doc(description="Activate account with invitation token")
|
||||
@api.expect(active_parser)
|
||||
@api.response(
|
||||
@console_ns.doc("activate_account")
|
||||
@console_ns.doc(description="Activate account with invitation token")
|
||||
@console_ns.expect(active_parser)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"ActivationResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
@ -87,7 +83,7 @@ class ActivateApi(Resource):
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Already activated or invalid token")
|
||||
@console_ns.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
args = active_parser.parse_args()
|
||||
|
||||
@ -103,7 +99,7 @@ class ActivateApi(Resource):
|
||||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from libs.login import login_required
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
@ -16,7 +15,8 @@ class ApiKeyAuthDataSource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
"sources": [
|
||||
@ -39,18 +39,20 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
@ -61,11 +63,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -2,12 +2,11 @@ import logging
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
@ -30,23 +29,22 @@ def get_oauth_providers():
|
||||
|
||||
@console_ns.route("/oauth/data-source/<string:provider>")
|
||||
class OAuthDataSource(Resource):
|
||||
@api.doc("oauth_data_source")
|
||||
@api.doc(description="Get OAuth authorization URL for data source provider")
|
||||
@api.doc(params={"provider": "Data source provider name (notion)"})
|
||||
@api.response(
|
||||
@console_ns.doc("oauth_data_source")
|
||||
@console_ns.doc(description="Get OAuth authorization URL for data source provider")
|
||||
@console_ns.doc(params={"provider": "Data source provider name (notion)"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Authorization URL or internal setup success",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"OAuthDataSourceResponse",
|
||||
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid provider")
|
||||
@api.response(403, "Admin privileges required")
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@is_admin_or_owner_required
|
||||
def get(self, provider: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
@ -65,17 +63,17 @@ class OAuthDataSource(Resource):
|
||||
|
||||
@console_ns.route("/oauth/data-source/callback/<string:provider>")
|
||||
class OAuthDataSourceCallback(Resource):
|
||||
@api.doc("oauth_data_source_callback")
|
||||
@api.doc(description="Handle OAuth callback from data source provider")
|
||||
@api.doc(
|
||||
@console_ns.doc("oauth_data_source_callback")
|
||||
@console_ns.doc(description="Handle OAuth callback from data source provider")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"provider": "Data source provider name (notion)",
|
||||
"code": "Authorization code from OAuth provider",
|
||||
"error": "Error message from OAuth provider",
|
||||
}
|
||||
)
|
||||
@api.response(302, "Redirect to console with result")
|
||||
@api.response(400, "Invalid provider")
|
||||
@console_ns.response(302, "Redirect to console with result")
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
@ -96,17 +94,17 @@ class OAuthDataSourceCallback(Resource):
|
||||
|
||||
@console_ns.route("/oauth/data-source/binding/<string:provider>")
|
||||
class OAuthDataSourceBinding(Resource):
|
||||
@api.doc("oauth_data_source_binding")
|
||||
@api.doc(description="Bind OAuth data source with authorization code")
|
||||
@api.doc(
|
||||
@console_ns.doc("oauth_data_source_binding")
|
||||
@console_ns.doc(description="Bind OAuth data source with authorization code")
|
||||
@console_ns.doc(
|
||||
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Data source binding success",
|
||||
api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Invalid provider or code")
|
||||
@console_ns.response(400, "Invalid provider or code")
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
@ -130,15 +128,15 @@ class OAuthDataSourceBinding(Resource):
|
||||
|
||||
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
|
||||
class OAuthDataSourceSync(Resource):
|
||||
@api.doc("oauth_data_source_sync")
|
||||
@api.doc(description="Sync data from OAuth data source")
|
||||
@api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
|
||||
@api.response(
|
||||
@console_ns.doc("oauth_data_source_sync")
|
||||
@console_ns.doc(description="Sync data from OAuth data source")
|
||||
@console_ns.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Data source sync success",
|
||||
api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Invalid provider or sync failed")
|
||||
@console_ns.response(400, "Invalid provider or sync failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -19,7 +19,7 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
@ -31,9 +31,11 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
@ -59,10 +61,12 @@ class EmailRegisterCheckApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
@ -100,10 +104,12 @@ class EmailRegisterResetApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate passwords match
|
||||
|
||||
@ -6,7 +6,7 @@ from flask_restx import Resource, fields, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
@ -20,17 +20,17 @@ from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@api.doc("send_forgot_password_email")
|
||||
@api.doc(description="Send password reset email")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("send_forgot_password_email")
|
||||
@console_ns.doc(description="Send password reset email")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
@ -38,10 +38,10 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"ForgotPasswordEmailResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
@ -50,13 +50,15 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid email or rate limit exceeded")
|
||||
@console_ns.response(400, "Invalid email or rate limit exceeded")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
@ -83,10 +85,10 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
|
||||
@console_ns.route("/forgot-password/validity")
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@api.doc("check_forgot_password_code")
|
||||
@api.doc(description="Verify password reset code")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("check_forgot_password_code")
|
||||
@console_ns.doc(description="Verify password reset code")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
@ -95,10 +97,10 @@ class ForgotPasswordCheckApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
api.model(
|
||||
console_ns.model(
|
||||
"ForgotPasswordCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether code is valid"),
|
||||
@ -107,14 +109,16 @@ class ForgotPasswordCheckApi(Resource):
|
||||
},
|
||||
),
|
||||
)
|
||||
@api.response(400, "Invalid code or token")
|
||||
@console_ns.response(400, "Invalid code or token")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
@ -148,10 +152,10 @@ class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
@console_ns.route("/forgot-password/resets")
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@api.doc("reset_password")
|
||||
@api.doc(description="Reset password with verification token")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("reset_password")
|
||||
@console_ns.doc(description="Reset password with verification token")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ForgotPasswordResetRequest",
|
||||
{
|
||||
"token": fields.String(required=True, description="Verification token"),
|
||||
@ -160,19 +164,21 @@ class ForgotPasswordResetApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
||||
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@api.response(400, "Invalid token or password mismatch")
|
||||
@console_ns.response(400, "Invalid token or password mismatch")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate passwords match
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import request
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from constants.languages import get_valid_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
@ -26,7 +24,16 @@ from controllers.console.error import (
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
clear_refresh_token_from_cookie,
|
||||
extract_refresh_token,
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountRegisterError
|
||||
@ -42,11 +49,13 @@ class LoginApi(Resource):
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("password", type=str, required=True, location="json")
|
||||
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=str, required=True, location="json")
|
||||
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
@ -89,19 +98,36 @@ class LoginApi(Resource):
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/logout")
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
return {"result": "success"}
|
||||
AccountService.logout(account=account)
|
||||
flask_login.logout_user()
|
||||
return {"result": "success"}
|
||||
response = make_response({"result": "success"})
|
||||
else:
|
||||
AccountService.logout(account=account)
|
||||
flask_login.logout_user()
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
# Clear cookies on logout
|
||||
clear_access_token_from_cookie(response)
|
||||
clear_refresh_token_from_cookie(response)
|
||||
clear_csrf_token_from_cookie(response)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/reset-password")
|
||||
@ -109,9 +135,11 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
@ -137,9 +165,11 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
@ -170,13 +200,17 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
language = args["language"]
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
@ -210,7 +244,9 @@ class EmailCodeLoginApi(Resource):
|
||||
if account is None:
|
||||
try:
|
||||
account = AccountService.create_account_and_tenant(
|
||||
email=user_email, name=user_email, interface_language=languages[0]
|
||||
email=user_email,
|
||||
name=user_email,
|
||||
interface_language=get_valid_language(language),
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
@ -220,18 +256,36 @@ class EmailCodeLoginApi(Resource):
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
# Set HTTP-only secure cookies for tokens
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/refresh-token")
|
||||
class RefreshTokenApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("refresh_token", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
# Get refresh token from cookie instead of request body
|
||||
refresh_token = extract_refresh_token(request)
|
||||
|
||||
if not refresh_token:
|
||||
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||
|
||||
try:
|
||||
new_token_pair = AccountService.refresh_token(args["refresh_token"])
|
||||
return {"result": "success", "data": new_token_pair.model_dump()}
|
||||
new_token_pair = AccountService.refresh_token(refresh_token)
|
||||
|
||||
# Create response with new cookies
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
# Update cookies with new tokens
|
||||
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
|
||||
set_access_token_to_cookie(request, response, new_token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
|
||||
return response
|
||||
except Exception as e:
|
||||
return {"result": "fail", "data": str(e)}, 401
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
|
||||
@ -14,15 +14,19 @@ from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from libs.token import (
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from models import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import api, console_ns
|
||||
from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -52,11 +56,13 @@ def get_oauth_providers():
|
||||
|
||||
@console_ns.route("/oauth/login/<provider>")
|
||||
class OAuthLogin(Resource):
|
||||
@api.doc("oauth_login")
|
||||
@api.doc(description="Initiate OAuth login process")
|
||||
@api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"})
|
||||
@api.response(302, "Redirect to OAuth authorization URL")
|
||||
@api.response(400, "Invalid provider")
|
||||
@console_ns.doc("oauth_login")
|
||||
@console_ns.doc(description="Initiate OAuth login process")
|
||||
@console_ns.doc(
|
||||
params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}
|
||||
)
|
||||
@console_ns.response(302, "Redirect to OAuth authorization URL")
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
invite_token = request.args.get("invite_token") or None
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
@ -71,17 +77,17 @@ class OAuthLogin(Resource):
|
||||
|
||||
@console_ns.route("/oauth/authorize/<provider>")
|
||||
class OAuthCallback(Resource):
|
||||
@api.doc("oauth_callback")
|
||||
@api.doc(description="Handle OAuth callback and complete login process")
|
||||
@api.doc(
|
||||
@console_ns.doc("oauth_callback")
|
||||
@console_ns.doc(description="Handle OAuth callback and complete login process")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"provider": "OAuth provider name (github/google)",
|
||||
"code": "Authorization code from OAuth provider",
|
||||
"state": "Optional state parameter (used for invite token)",
|
||||
}
|
||||
)
|
||||
@api.response(302, "Redirect to console with access token")
|
||||
@api.response(400, "OAuth process failed")
|
||||
@console_ns.response(302, "Redirect to console with access token")
|
||||
@console_ns.response(400, "OAuth process failed")
|
||||
def get(self, provider: str):
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
@ -130,11 +136,11 @@ class OAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
|
||||
|
||||
# Check account status
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
if account.status == AccountStatus.BANNED:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
if account.status == AccountStatus.PENDING:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
@ -153,9 +159,12 @@ class OAuthCallback(Resource):
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
|
||||
)
|
||||
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
return response
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:
|
||||
|
||||
@ -1,16 +1,15 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
import flask_login
|
||||
from flask import jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
|
||||
@ -24,8 +23,7 @@ T = TypeVar("T")
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
@ -91,8 +89,7 @@ class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
|
||||
@ -116,7 +113,8 @@ class OAuthServerUserAuthorizeApi(Resource):
|
||||
@account_initialization_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
user_account_id = account.id
|
||||
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
@ -132,12 +130,14 @@ class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("grant_type", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=False, location="json")
|
||||
parser.add_argument("client_secret", type=str, required=False, location="json")
|
||||
parser.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("grant_type", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=False, location="json")
|
||||
.add_argument("client_secret", type=str, required=False, location="json")
|
||||
.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
)
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
import base64
|
||||
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -14,17 +17,21 @@ class Subscription(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
args = parser.parse_args()
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_subscription(
|
||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"plan",
|
||||
type=str,
|
||||
required=True,
|
||||
location="args",
|
||||
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
||||
)
|
||||
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
)
|
||||
args = parser.parse_args()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@console_ns.route("/billing/invoices")
|
||||
@ -34,7 +41,40 @@ class Invoices(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@console_ns.route("/billing/partners/<string:partner_key>/tenants")
|
||||
class PartnerTenants(Resource):
|
||||
@console_ns.doc("sync_partner_tenants_bindings")
|
||||
@console_ns.doc(description="Sync partner tenants bindings")
|
||||
@console_ns.doc(params={"partner_key": "Partner key"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"SyncPartnerTenantsBindingsRequest",
|
||||
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Tenants synced to partner successfully")
|
||||
@console_ns.response(400, "Invalid partner information")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
click_id = args["click_id"]
|
||||
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
||||
except Exception:
|
||||
raise BadRequest("Invalid partner_key")
|
||||
|
||||
if not click_id or not decoded_partner_key or not current_user.id:
|
||||
raise BadRequest("Invalid partner information")
|
||||
|
||||
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
@ -17,17 +16,16 @@ class ComplianceApi(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||
|
||||
return BillingService.get_compliance_download_link(
|
||||
doc_name=args.doc_name,
|
||||
account_id=current_user.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
ip=ip_address,
|
||||
device_info=device_info,
|
||||
)
|
||||
|
||||
@ -3,7 +3,6 @@ from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -15,12 +14,12 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType,
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
@ -37,10 +36,12 @@ class DataSourceApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_list_fields)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.tenant_id == current_tenant_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
)
|
||||
).all()
|
||||
@ -120,13 +121,15 @@ class DataSourceNotionListApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
@ -146,7 +149,7 @@ class DataSourceNotionListApi(Resource):
|
||||
documents = session.scalars(
|
||||
select(Document).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
)
|
||||
@ -161,7 +164,7 @@ class DataSourceNotionListApi(Resource):
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||
datasource_name="notion_datasource",
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -210,12 +213,14 @@ class DataSourceNotionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
@ -229,7 +234,7 @@ class DataSourceNotionApi(Resource):
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
text_docs = extractor.extract()
|
||||
@ -239,12 +244,14 @@ class DataSourceNotionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
@ -256,20 +263,22 @@ class DataSourceNotionApi(Resource):
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
},
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
extract_settings,
|
||||
args["process_rule"],
|
||||
args["doc_form"],
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -6,13 +6,12 @@ from typing import Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import asc, desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
@ -44,18 +43,19 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.document_fields import (
|
||||
dataset_and_document_fields,
|
||||
document_fields,
|
||||
document_metadata_fields,
|
||||
document_status_fields,
|
||||
document_with_segments_fields,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.account import Account
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
@ -63,8 +63,39 @@ from services.entities.knowledge_entities.knowledge_entities import KnowledgeCon
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_model = _get_or_create_model("Dataset", dataset_fields)
|
||||
|
||||
document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
|
||||
|
||||
document_fields_copy = document_fields.copy()
|
||||
document_fields_copy["doc_metadata"] = fields.List(
|
||||
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
|
||||
)
|
||||
document_model = _get_or_create_model("Document", document_fields_copy)
|
||||
|
||||
document_with_segments_fields_copy = document_with_segments_fields.copy()
|
||||
document_with_segments_fields_copy["doc_metadata"] = fields.List(
|
||||
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
|
||||
)
|
||||
document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
|
||||
|
||||
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
|
||||
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
|
||||
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
|
||||
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -79,12 +110,13 @@ class DocumentResource(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
if document.tenant_id != current_user.current_tenant_id:
|
||||
if document.tenant_id != current_tenant_id:
|
||||
raise Forbidden("No permission.")
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -104,14 +136,15 @@ class DocumentResource(Resource):
|
||||
|
||||
@console_ns.route("/datasets/process-rule")
|
||||
class GetProcessRuleApi(Resource):
|
||||
@api.doc("get_process_rule")
|
||||
@api.doc(description="Get dataset document processing rules")
|
||||
@api.doc(params={"document_id": "Document ID (optional)"})
|
||||
@api.response(200, "Process rules retrieved successfully")
|
||||
@console_ns.doc("get_process_rule")
|
||||
@console_ns.doc(description="Get dataset document processing rules")
|
||||
@console_ns.doc(params={"document_id": "Document ID (optional)"})
|
||||
@console_ns.response(200, "Process rules retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
req_data = request.args
|
||||
|
||||
document_id = req_data.get("document_id")
|
||||
@ -151,9 +184,9 @@ class GetProcessRuleApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents")
|
||||
class DatasetDocumentListApi(Resource):
|
||||
@api.doc("get_dataset_documents")
|
||||
@api.doc(description="Get documents in a dataset")
|
||||
@api.doc(
|
||||
@console_ns.doc("get_dataset_documents")
|
||||
@console_ns.doc(description="Get documents in a dataset")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"dataset_id": "Dataset ID",
|
||||
"page": "Page number (default: 1)",
|
||||
@ -161,18 +194,20 @@ class DatasetDocumentListApi(Resource):
|
||||
"keyword": "Search keyword",
|
||||
"sort": "Sort order (default: -created_at)",
|
||||
"fetch": "Fetch full details (default: false)",
|
||||
"status": "Filter documents by display status",
|
||||
}
|
||||
)
|
||||
@api.response(200, "Documents retrieved successfully")
|
||||
@console_ns.response(200, "Documents retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
def get(self, dataset_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
sort = request.args.get("sort", default="-created_at", type=str)
|
||||
status = request.args.get("status", default=None, type=str)
|
||||
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
|
||||
try:
|
||||
fetch_val = request.args.get("fetch", default="false")
|
||||
@ -199,7 +234,10 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
|
||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
|
||||
|
||||
if status:
|
||||
query = DocumentService.apply_display_status_filter(query, status)
|
||||
|
||||
if search:
|
||||
search = f"%{search}%"
|
||||
@ -269,10 +307,11 @@ class DatasetDocumentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_fields)
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -289,23 +328,23 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument("data_source", type=dict, required=False, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=False, location="json")
|
||||
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
||||
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||
)
|
||||
.add_argument("data_source", type=dict, required=False, location="json")
|
||||
.add_argument("process_rule", type=dict, required=False, location="json")
|
||||
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
|
||||
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
@ -349,10 +388,10 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/init")
|
||||
class DatasetInitApi(Resource):
|
||||
@api.doc("init_dataset")
|
||||
@api.doc(description="Initialize dataset with documents")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("init_dataset")
|
||||
@console_ns.doc(description="Initialize dataset with documents")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"DatasetInitRequest",
|
||||
{
|
||||
"upload_file_id": fields.String(required=True, description="Upload file ID"),
|
||||
@ -362,47 +401,48 @@ class DatasetInitApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "Dataset initialized successfully", dataset_and_document_fields)
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_fields)
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
required=True,
|
||||
nullable=False,
|
||||
location="json",
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
required=True,
|
||||
nullable=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=args["embedding_model_provider"],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=args["embedding_model"],
|
||||
@ -419,9 +459,9 @@ class DatasetInitApi(Resource):
|
||||
|
||||
try:
|
||||
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
knowledge_config=knowledge_config,
|
||||
account=cast(Account, current_user),
|
||||
account=current_user,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -437,16 +477,17 @@ class DatasetInitApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
|
||||
class DocumentIndexingEstimateApi(DocumentResource):
|
||||
@api.doc("estimate_document_indexing")
|
||||
@api.doc(description="Estimate document indexing cost")
|
||||
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@api.response(200, "Indexing estimate calculated successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@api.response(400, "Document already finished")
|
||||
@console_ns.doc("estimate_document_indexing")
|
||||
@console_ns.doc(description="Estimate document indexing cost")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Indexing estimate calculated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(400, "Document already finished")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
@ -475,14 +516,14 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
||||
try:
|
||||
estimate_response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
[extract_setting],
|
||||
data_process_rule_dict,
|
||||
document.doc_form,
|
||||
@ -511,6 +552,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, batch):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
documents = self.get_batch_documents(dataset_id, batch)
|
||||
@ -530,7 +572,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -538,7 +580,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
@ -546,14 +588,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
},
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
@ -561,15 +605,17 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
},
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
@ -579,7 +625,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
extract_settings,
|
||||
data_process_rule_dict,
|
||||
document.doc_form,
|
||||
@ -646,11 +692,11 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
class DocumentIndexingStatusApi(DocumentResource):
|
||||
@api.doc("get_document_indexing_status")
|
||||
@api.doc(description="Get document indexing status")
|
||||
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@api.response(200, "Indexing status retrieved successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@console_ns.doc("get_document_indexing_status")
|
||||
@console_ns.doc(description="Get document indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -696,17 +742,17 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
class DocumentApi(DocumentResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
@api.doc("get_document")
|
||||
@api.doc(description="Get document details")
|
||||
@api.doc(
|
||||
@console_ns.doc("get_document")
|
||||
@console_ns.doc(description="Get document details")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"dataset_id": "Dataset ID",
|
||||
"document_id": "Document ID",
|
||||
"metadata": "Metadata inclusion (all/only/without)",
|
||||
}
|
||||
)
|
||||
@api.response(200, "Document retrieved successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@console_ns.response(200, "Document retrieved successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -736,7 +782,7 @@ class DocumentApi(DocumentResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
@ -769,7 +815,7 @@ class DocumentApi(DocumentResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
@ -817,19 +863,20 @@ class DocumentApi(DocumentResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>")
|
||||
class DocumentProcessingApi(DocumentResource):
|
||||
@api.doc("update_document_processing")
|
||||
@api.doc(description="Update document processing status (pause/resume)")
|
||||
@api.doc(
|
||||
@console_ns.doc("update_document_processing")
|
||||
@console_ns.doc(description="Update document processing status (pause/resume)")
|
||||
@console_ns.doc(
|
||||
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"}
|
||||
)
|
||||
@api.response(200, "Processing status updated successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@api.response(400, "Invalid action")
|
||||
@console_ns.response(200, "Processing status updated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(400, "Invalid action")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
@ -861,11 +908,11 @@ class DocumentProcessingApi(DocumentResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
|
||||
class DocumentMetadataApi(DocumentResource):
|
||||
@api.doc("update_document_metadata")
|
||||
@api.doc(description="Update document metadata")
|
||||
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("update_document_metadata")
|
||||
@console_ns.doc(description="Update document metadata")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateDocumentMetadataRequest",
|
||||
{
|
||||
"doc_type": fields.String(description="Document type"),
|
||||
@ -873,13 +920,14 @@ class DocumentMetadataApi(DocumentResource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Document metadata updated successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.response(200, "Document metadata updated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
@ -927,6 +975,7 @@ class DocumentStatusApi(DocumentResource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
@ -1030,8 +1079,9 @@ class DocumentRetryApi(DocumentResource):
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"document_ids", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -1073,14 +1123,14 @@ class DocumentRenameApi(DocumentResource):
|
||||
@marshal_with(document_fields)
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -1098,6 +1148,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
@ -1106,7 +1157,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if document.tenant_id != current_user.current_tenant_id:
|
||||
if document.tenant_id != current_tenant_id:
|
||||
raise Forbidden("No permission.")
|
||||
if document.data_source_type != "website_crawl":
|
||||
raise ValueError("Document is not a website document.")
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
@ -27,7 +26,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
@ -43,6 +42,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -59,13 +60,15 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("limit", type=int, default=20, location="args")
|
||||
parser.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||
parser.add_argument("enabled", type=str, default="all", location="args")
|
||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
||||
parser.add_argument("page", type=int, default=1, location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||
.add_argument("enabled", type=str, default="all", location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -79,7 +82,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id,
|
||||
DocumentSegment.tenant_id == current_tenant_id,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
)
|
||||
@ -115,6 +118,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -148,6 +153,8 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
@ -171,7 +178,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -204,6 +211,8 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -221,7 +230,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -237,10 +246,12 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.create_segment(args, document, dataset)
|
||||
@ -255,6 +266,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -272,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -287,7 +300,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -300,16 +313,18 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
parser.add_argument(
|
||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
.add_argument(
|
||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
@ -317,6 +332,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -333,7 +350,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -361,6 +378,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -372,8 +391,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"upload_file_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
upload_file_id = args["upload_file_id"]
|
||||
|
||||
@ -396,7 +416,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
upload_file_id,
|
||||
dataset_id,
|
||||
document_id,
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
current_user.id,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -427,6 +447,8 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -441,7 +463,7 @@ class ChildChunkAddApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -453,7 +475,7 @@ class ChildChunkAddApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -469,8 +491,9 @@ class ChildChunkAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
content = args["content"]
|
||||
@ -483,6 +506,8 @@ class ChildChunkAddApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id, segment_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -499,15 +524,17 @@ class ChildChunkAddApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("limit", type=int, default=20, location="args")
|
||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
||||
parser.add_argument("page", type=int, default=1, location="args")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -530,6 +557,8 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -546,7 +575,7 @@ class ChildChunkAddApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -559,12 +588,13 @@ class ChildChunkAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"chunks", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
chunks_data = args["chunks"]
|
||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
|
||||
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
|
||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
@ -580,6 +610,8 @@ class ChildChunkUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -596,7 +628,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -607,7 +639,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
@ -634,6 +666,8 @@ class ChildChunkUpdateApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -650,7 +684,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -661,7 +695,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
@ -677,8 +711,9 @@ class ChildChunkUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
content = args["content"]
|
||||
|
||||
@ -1,23 +1,76 @@
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
doc_metadata_fields,
|
||||
external_knowledge_info_fields,
|
||||
external_retrieval_model_fields,
|
||||
icon_info_fields,
|
||||
keyword_setting_fields,
|
||||
reranking_model_fields,
|
||||
tag_fields,
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
def _build_dataset_detail_model():
|
||||
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
|
||||
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
|
||||
|
||||
weighted_score_fields_copy = weighted_score_fields.copy()
|
||||
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
|
||||
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
|
||||
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
|
||||
|
||||
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
|
||||
|
||||
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
|
||||
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
|
||||
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
|
||||
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
|
||||
|
||||
tag_model = _get_or_create_model("Tag", tag_fields)
|
||||
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
|
||||
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
|
||||
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
|
||||
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
|
||||
|
||||
dataset_detail_fields_copy = dataset_detail_fields.copy()
|
||||
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
|
||||
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
|
||||
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
|
||||
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
|
||||
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
|
||||
return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
|
||||
|
||||
|
||||
try:
|
||||
dataset_detail_model = console_ns.models["DatasetDetail"]
|
||||
except KeyError:
|
||||
dataset_detail_model = _build_dataset_detail_model()
|
||||
|
||||
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
@ -26,26 +79,27 @@ def _validate_name(name: str) -> str:
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api")
|
||||
class ExternalApiTemplateListApi(Resource):
|
||||
@api.doc("get_external_api_templates")
|
||||
@api.doc(description="Get external knowledge API templates")
|
||||
@api.doc(
|
||||
@console_ns.doc("get_external_api_templates")
|
||||
@console_ns.doc(description="Get external knowledge API templates")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"page": "Page number (default: 1)",
|
||||
"limit": "Number of items per page (default: 20)",
|
||||
"keyword": "Search keyword",
|
||||
}
|
||||
)
|
||||
@api.response(200, "External API templates retrieved successfully")
|
||||
@console_ns.response(200, "External API templates retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
|
||||
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
||||
page, limit, current_user.current_tenant_id, search
|
||||
page, limit, current_tenant_id, search
|
||||
)
|
||||
response = {
|
||||
"data": [item.to_dict() for item in external_knowledge_apis],
|
||||
@ -60,20 +114,23 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -85,7 +142,7 @@ class ExternalApiTemplateListApi(Resource):
|
||||
|
||||
try:
|
||||
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, args=args
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
@ -95,11 +152,11 @@ class ExternalApiTemplateListApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
|
||||
class ExternalApiTemplateApi(Resource):
|
||||
@api.doc("get_external_api_template")
|
||||
@api.doc(description="Get external knowledge API template details")
|
||||
@api.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@api.response(200, "External API template retrieved successfully")
|
||||
@api.response(404, "Template not found")
|
||||
@console_ns.doc("get_external_api_template")
|
||||
@console_ns.doc(description="Get external knowledge API template details")
|
||||
@console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@console_ns.response(200, "External API template retrieved successfully")
|
||||
@console_ns.response(404, "Template not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -115,28 +172,31 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
|
||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
args=args,
|
||||
@ -148,22 +208,22 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_operator):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
|
||||
class ExternalApiUseCheckApi(Resource):
|
||||
@api.doc("check_external_api_usage")
|
||||
@api.doc(description="Check if external knowledge API is being used")
|
||||
@api.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@api.response(200, "Usage check completed successfully")
|
||||
@console_ns.doc("check_external_api_usage")
|
||||
@console_ns.doc(description="Check if external knowledge API is being used")
|
||||
@console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@console_ns.response(200, "Usage check completed successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -178,10 +238,10 @@ class ExternalApiUseCheckApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/external")
|
||||
class ExternalDatasetCreateApi(Resource):
|
||||
@api.doc("create_external_dataset")
|
||||
@api.doc(description="Create external knowledge dataset")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("create_external_dataset")
|
||||
@console_ns.doc(description="Create external knowledge dataset")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"CreateExternalDatasetRequest",
|
||||
{
|
||||
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
|
||||
@ -191,29 +251,30 @@ class ExternalDatasetCreateApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "External dataset created successfully", dataset_detail_fields)
|
||||
@api.response(400, "Invalid parameters")
|
||||
@api.response(403, "Permission denied")
|
||||
@console_ns.response(201, "External dataset created successfully", dataset_detail_model)
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -223,7 +284,7 @@ class ExternalDatasetCreateApi(Resource):
|
||||
|
||||
try:
|
||||
dataset = ExternalDatasetService.create_external_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
args=args,
|
||||
)
|
||||
@ -235,11 +296,11 @@ class ExternalDatasetCreateApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/external-hit-testing")
|
||||
class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@api.doc("test_external_knowledge_retrieval")
|
||||
@api.doc(description="Test external knowledge retrieval for dataset")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("test_external_knowledge_retrieval")
|
||||
@console_ns.doc(description="Test external knowledge retrieval for dataset")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"ExternalHitTestingRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="Query text for testing"),
|
||||
@ -248,13 +309,14 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "External hit testing completed successfully")
|
||||
@api.response(404, "Dataset not found")
|
||||
@api.response(400, "Invalid parameters")
|
||||
@console_ns.response(200, "External hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -265,10 +327,12 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
@ -277,7 +341,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
response = HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=cast(Account, current_user),
|
||||
account=current_user,
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
||||
)
|
||||
@ -290,10 +354,10 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@console_ns.route("/test/retrieval")
|
||||
class BedrockRetrievalApi(Resource):
|
||||
# this api is only for internal testing
|
||||
@api.doc("bedrock_retrieval_test")
|
||||
@api.doc(description="Bedrock retrieval test (internal use only)")
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("bedrock_retrieval_test")
|
||||
@console_ns.doc(description="Bedrock retrieval test (internal use only)")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"BedrockRetrievalTestRequest",
|
||||
{
|
||||
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
|
||||
@ -302,17 +366,19 @@ class BedrockRetrievalApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Bedrock retrieval test completed")
|
||||
@console_ns.response(200, "Bedrock retrieval test completed")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
||||
parser.add_argument(
|
||||
"query",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
||||
.add_argument(
|
||||
"query",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
)
|
||||
.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
||||
)
|
||||
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Call the knowledge retrieval service
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -12,11 +12,11 @@ from libs.login import login_required
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@api.doc("test_dataset_retrieval")
|
||||
@api.doc(description="Test dataset knowledge retrieval")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
@console_ns.doc("test_dataset_retrieval")
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"HitTestingRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="Query text for testing"),
|
||||
@ -26,9 +26,9 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Hit testing completed successfully")
|
||||
@api.response(404, "Dataset not found")
|
||||
@api.response(400, "Invalid parameters")
|
||||
@console_ns.response(200, "Hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
@ -21,6 +19,7 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -31,6 +30,7 @@ logger = logging.getLogger(__name__)
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -48,20 +48,22 @@ class DatasetsHitTestingBase:
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=cast(Account, current_user),
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
@ -24,11 +23,14 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def post(self, dataset_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataArgs(**args)
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -59,8 +61,8 @@ class DatasetMetadataApi(Resource):
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
name = args["name"]
|
||||
|
||||
@ -79,6 +81,7 @@ class DatasetMetadataApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -108,6 +111,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -128,16 +132,18 @@ class DocumentMetadataEditApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"operation_data", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataOperationData(**args)
|
||||
metadata_args = MetadataOperationData.model_validate(args)
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
|
||||
@ -1,19 +1,15 @@
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from flask import make_response, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
@ -24,11 +20,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, provider_id: str):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
@ -52,7 +48,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=current_user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
@ -125,27 +121,30 @@ class DatasourceOAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_datasource = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@console_ns.expect(parser_datasource)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
||||
)
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
args = parser_datasource.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
try:
|
||||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
@ -160,31 +159,39 @@ class DatasourceAuth(Resource):
|
||||
def get(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
parser_datasource_delete = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@console_ns.expect(parser_datasource_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
args = parser_datasource_delete.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
@ -192,23 +199,30 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_datasource_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@console_ns.expect(parser_datasource_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
args = parser_datasource_update.parse_args()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
@ -224,10 +238,10 @@ class DatasourceAuthListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
@ -237,29 +251,35 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
parser_datasource_custom = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@console_ns.expect(parser_datasource_custom)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_datasource_custom.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
@ -270,52 +290,63 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@console_ns.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_default.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_update_name = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@console_ns.expect(parser_update_name)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_update_name.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask_restx import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -12,9 +12,21 @@ from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class Parser(BaseModel):
|
||||
inputs: dict
|
||||
datasource_type: str
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -26,19 +38,10 @@ class DataSourceContentPreviewApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
args = Parser.model_validate(console_ns.payload)
|
||||
|
||||
inputs = args.inputs
|
||||
datasource_type = args.datasource_type
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
preview_content = rag_pipeline_service.run_datasource_node_preview(
|
||||
pipeline=pipeline,
|
||||
@ -47,6 +50,6 @@ class DataSourceContentPreviewApi(Resource):
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=True,
|
||||
credential_id=args.get("credential_id"),
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
return preview_content, 200
|
||||
|
||||
@ -66,29 +66,31 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
@ -123,26 +125,28 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@ -13,7 +12,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
@ -27,9 +26,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument(
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
@ -38,7 +35,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@ -58,12 +55,12 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_user.current_tenant_id,
|
||||
current_tenant_id,
|
||||
import_info["dataset_id"],
|
||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||
)
|
||||
@ -81,10 +78,12 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
|
||||
@ -23,7 +23,7 @@ from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -33,16 +33,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
@ -206,10 +208,11 @@ class RagPipelineVariableApi(Resource):
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user