Mirror of https://github.com/langgenius/dify.git (synced 2026-01-19 19:55:06 +08:00)

Compare commits: 953 commits, 1.9.2...feat/trigg
```diff
@@ -11,7 +11,7 @@
         "nodeGypDependencies": true,
         "version": "lts"
     },
-    "ghcr.io/devcontainers-contrib/features/npm-package:1": {
+    "ghcr.io/devcontainers-extra/features/npm-package:1": {
         "package": "typescript",
         "version": "latest"
     },
```
```diff
@@ -6,7 +6,7 @@ cd web && pnpm install
 pipx install uv
 
 echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
-echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
+echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
 echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
 echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
 echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
```
.github/workflows/autofix.yml (vendored, 2 changes)

```diff
@@ -2,6 +2,8 @@ name: autofix.ci
 on:
   pull_request:
+    branches: ["main"]
   push:
+    branches: ["main"]
 permissions:
   contents: read
```
.github/workflows/style.yml (vendored, 5 changes)

```diff
@@ -103,6 +103,11 @@ jobs:
       run: |
         pnpm run lint
 
+    - name: Web type check
+      if: steps.changed-files.outputs.any_changed == 'true'
+      working-directory: ./web
+      run: pnpm run type-check
+
   docker-compose-template:
     name: Docker Compose Template
     runs-on: ubuntu-latest
```
.gitignore (vendored, 9 changes)

```diff
@@ -6,6 +6,9 @@ __pycache__/
 # C extensions
 *.so
 
+# *db files
+*.db
+
 # Distribution / packaging
 .Python
 build/
@@ -97,6 +100,7 @@ __pypackages__/
 
 # Celery stuff
 celerybeat-schedule
+celerybeat-schedule.db
 celerybeat.pid
 
 # SageMath parsed files
@@ -234,4 +238,7 @@ scripts/stress-test/reports/
 
 # mcp
 .playwright-mcp/
 .serena/
+
+# settings
+*.local.json
```
.vscode/launch.json.template (vendored, 9 changes)

```diff
@@ -8,8 +8,7 @@
             "module": "flask",
             "env": {
                 "FLASK_APP": "app.py",
-                "FLASK_ENV": "development",
-                "GEVENT_SUPPORT": "True"
+                "FLASK_ENV": "development"
             },
             "args": [
                 "run",
@@ -28,9 +27,7 @@
             "type": "debugpy",
             "request": "launch",
             "module": "celery",
-            "env": {
-                "GEVENT_SUPPORT": "True"
-            },
+            "env": {},
             "args": [
                 "-A",
                 "app.celery",
@@ -40,7 +37,7 @@
                 "-c",
                 "1",
                 "-Q",
-                "dataset,generation,mail,ops_trace",
+                "dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
                 "--loglevel",
                 "INFO"
             ],
```
README.md (10 changes)

```diff
@@ -63,7 +63,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
 > - CPU >= 2 Core
 > - RAM >= 4 GiB
 
-</br>
+<br/>
 
 The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
 
@@ -109,15 +109,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
 
 ## Using Dify
 
-- **Cloud </br>**
+- **Cloud <br/>**
   We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
 
-- **Self-hosting Dify Community Edition</br>**
+- **Self-hosting Dify Community Edition<br/>**
   Quickly get Dify running in your environment with this [starter guide](#quick-start).
   Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
 
-- **Dify for enterprise / organizations</br>**
-  We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br>
+- **Dify for enterprise / organizations<br/>**
+  We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
 
 > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
```
```diff
@@ -27,6 +27,9 @@ FILES_URL=http://localhost:5001
 # Example: INTERNAL_FILES_URL=http://api:5001
 INTERNAL_FILES_URL=http://127.0.0.1:5001
 
+# TRIGGER URL
+TRIGGER_URL=http://localhost:5001
+
 # The time in seconds after the signature is rejected
 FILES_ACCESS_TIMEOUT=300
 
@@ -156,6 +159,9 @@ SUPABASE_URL=your-server-url
 # CORS configuration
 WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
+# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
+# Provide the registrable domain (e.g. example.com); leading dots are optional.
+COOKIE_DOMAIN=
 
 # Vector database configuration
 # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
@@ -368,6 +374,12 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
 UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 
+# Comma-separated list of file extensions blocked from upload for security reasons.
+# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
+# Empty by default to allow all file types.
+# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
+UPLOAD_FILE_EXTENSION_BLACKLIST=
+
 # Model configuration
 MULTIMODAL_SEND_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
@@ -457,6 +469,9 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
 HTTP_REQUEST_NODE_SSL_VERIFY=True
 
+# Webhook request configuration
+WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
+
 # Respect X-* headers to redirect clients
 RESPECT_XFORWARD_HEADERS_ENABLED=false
 
@@ -512,7 +527,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
 API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
 # Workflow log cleanup configuration
 # Enable automatic cleanup of workflow run logs to manage database size
-WORKFLOW_LOG_CLEANUP_ENABLED=true
+WORKFLOW_LOG_CLEANUP_ENABLED=false
 # Number of days to retain workflow run logs (default: 30 days)
 WORKFLOW_LOG_RETENTION_DAYS=30
 # Batch size for workflow log cleanup operations (default: 100)
@@ -534,6 +549,12 @@ ENABLE_CLEAN_MESSAGES=false
 ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
 ENABLE_DATASETS_QUEUE_MONITOR=false
 ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
+ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
+# Interval time in minutes for polling scheduled workflows(default: 1 min)
+WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
+WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
+# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
+WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
 
 # Position configuration
 POSITION_TOOL_PINS=
@@ -605,3 +626,9 @@ SWAGGER_UI_PATH=/swagger-ui.html
 # Whether to encrypt dataset IDs when exporting DSL files (default: true)
 # Set to false to export dataset IDs as plain text for easier cross-environment import
 DSL_EXPORT_ENCRYPT_DATASET_ID=true
+
+# Tenant isolated task queue configuration
+TENANT_ISOLATED_TASK_CONCURRENCY=1
+
+# Maximum number of segments for dataset segments API (0 for unlimited)
+DATASET_MAX_SEGMENTS_PER_REQUEST=0
```
api/.vscode/launch.json.example (vendored, 2 changes)

```diff
@@ -54,7 +54,7 @@
                 "--loglevel",
                 "DEBUG",
                 "-Q",
-                "dataset,generation,mail,ops_trace,app_deletion"
+                "dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
             ]
         }
     ]
```
api/AGENTS.md (new file, 62 lines):

# Agent Skill Index

Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.

______________________________________________________________________

## Platform Foundations

- **[Infrastructure Overview](agent_skills/infra.md)**\
  When to read this:

  - You need to understand where a feature belongs in the architecture.
  - You’re wiring storage, Redis, vector stores, or OTEL.
  - You’re about to add CLI commands or async jobs.\
    What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.

- **[Coding Style](agent_skills/coding_style.md)**\
  When to read this:

  - You’re writing or reviewing backend code and need the authoritative checklist.
  - You’re unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
  - You want the exact lint/type/test commands used in PRs.\
    Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.

______________________________________________________________________

## Plugin & Extension Development

- **[Plugin Systems](agent_skills/plugin.md)**\
  When to read this:

  - You’re building or debugging a marketplace plugin.
  - You need to know how manifests, providers, daemons, and migrations fit together.\
    What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.

- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
  When to read this:

  - You must integrate OAuth for a plugin or datasource.
  - You’re handling credential encryption or refresh flows.\
    Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.

______________________________________________________________________

## Workflow Entry & Execution

- **[Trigger Concepts](agent_skills/trigger.md)**\
  When to read this:

  - You’re debugging why a workflow didn’t start.
  - You’re adding a new trigger type or hook.
  - You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
    Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.

______________________________________________________________________

## Additional Notes for Agents

- All skill docs assume you follow the coding style guide—run the Ruff/BasedPyright/tests listed there before submitting changes.
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
```diff
@@ -15,7 +15,11 @@ FROM base AS packages
 # RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
 
 RUN apt-get update \
-    && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
+    && apt-get install -y --no-install-recommends \
+    # basic environment
+    g++ \
+    # for building gmpy2
+    libmpfr-dev libmpc-dev
 
 # Install Python dependencies
 COPY pyproject.toml uv.lock ./
@@ -49,7 +53,9 @@ RUN \
     # Install dependencies
     && apt-get install -y --no-install-recommends \
     # basic environment
-    curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
+    curl nodejs \
+    # for gmpy2 \
+    libgmp-dev libmpfr-dev libmpc-dev \
     # For Security
     expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
     # install fonts to support the use of tools like pypdfium2
```
````diff
@@ -80,7 +80,7 @@
 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
 
    ```bash
-   uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
+   uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
    ```
 
 Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
````
api/agent_skills/coding_style.md (new file, 115 lines):

## Linter

- Always follow `.ruff.toml`.
- Run `uv run ruff check --fix --unsafe-fixes`.
- Keep each line under 100 characters (including spaces).

## Code Style

- `snake_case` for variables and functions.
- `PascalCase` for classes.
- `UPPER_CASE` for constants.

## Rules

- Use the Pydantic v2 standard.
- Use `uv` for package management.
- Do not override dunder methods like `__init__`, `__iadd__`, etc.
- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
- Prefer simple functions over classes for lightweight helpers.
- Keep files below 800 lines; split when necessary.
- Keep code readable—no clever hacks.
- Never use `print`; log with `logger = logging.getLogger(__name__)`.

## Guiding Principles

- Mirror the project’s layered architecture: controller → service → core/domain.
- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
- Optimise for observability: deterministic control flow, clear logging, actionable errors.

## SQLAlchemy Patterns

- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.

- Open sessions with context managers:

  ```python
  from sqlalchemy.orm import Session

  with Session(db.engine, expire_on_commit=False) as session:
      stmt = select(Workflow).where(
          Workflow.id == workflow_id,
          Workflow.tenant_id == tenant_id,
      )
      workflow = session.execute(stmt).scalar_one_or_none()
  ```

- Use SQLAlchemy expressions; avoid raw SQL unless necessary.

- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.

- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.); a sketch follows.
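A guarded write path might look like the following minimal sketch, reusing the session pattern above (the `updated_at` mutation is illustrative only):

```python
from datetime import UTC, datetime

from sqlalchemy import select
from sqlalchemy.orm import Session

with Session(db.engine, expire_on_commit=False) as session:
    stmt = (
        select(Workflow)
        .where(Workflow.id == workflow_id, Workflow.tenant_id == tenant_id)
        .with_for_update()  # row-level lock guards the write path
    )
    workflow = session.execute(stmt).scalar_one_or_none()
    if workflow is not None:
        workflow.updated_at = datetime.now(UTC)  # illustrative mutation
        session.commit()
```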
## Storage & External IO

- Access storage via `extensions.ext_storage.storage`.
- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
- Background tasks that touch storage must be idempotent and log the relevant object identifiers.

## Pydantic Usage

- Define DTOs with Pydantic v2 models and forbid extras by default.

- Use `@field_validator` / `@model_validator` for domain rules.

- Example:

  ```python
  from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator

  class TriggerConfig(BaseModel):
      endpoint: HttpUrl
      secret: str

      model_config = ConfigDict(extra="forbid")

      @field_validator("secret")
      def ensure_secret_prefix(cls, value: str) -> str:
          if not value.startswith("dify_"):
              raise ValueError("secret must start with dify_")
          return value
  ```

## Generics & Protocols

- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces); a sketch follows this list.
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
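A minimal sketch of the pattern (illustrative only; this is not an existing Dify interface):

```python
from typing import Generic, Protocol, TypeVar

T = TypeVar("T")

class Cache(Protocol[T]):
    """Behavioural contract: any object with matching get/set satisfies it."""

    def get(self, key: str) -> T | None: ...
    def set(self, key: str, value: T) -> None: ...

class InMemoryCache(Generic[T]):
    def __init__(self) -> None:
        self._data: dict[str, T] = {}

    def get(self, key: str) -> T | None:
        return self._data.get(key)

    def set(self, key: str, value: T) -> None:
        self._data[key] = value

# Structural typing: InMemoryCache never inherits from Cache, yet it type-checks.
provider_cache: Cache[str] = InMemoryCache[str]()
```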
## Error Handling & Logging

- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate them to HTTP responses in controllers.
- Declare `logger = logging.getLogger(__name__)` at module top.
- Include tenant/app/workflow identifiers in log context.
- Log retryable events at `warning`, terminal failures at `error`; see the sketch below.
## Tooling & Checks

- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
- Type checks: `uv run --directory api --dev basedpyright`.
- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
- Run all of the above before submitting your work.

## Controllers & Services

- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
- Document non-obvious behaviour with concise comments.

## Miscellaneous

- Use `configs.dify_config` for configuration—never read environment variables directly.
- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
- Keep experimental scripts under `dev/`; do not ship them in production builds.
api/agent_skills/infra.md (new file, 96 lines):

## Configuration

- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
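In code that looks like this minimal sketch (`WORKFLOW_LOG_CLEANUP_ENABLED` mirrors the `.env` key that appears earlier in this diff):

```python
from configs import dify_config  # the central config object; never read os.environ

def workflow_log_cleanup_enabled() -> bool:
    return bool(dify_config.WORKFLOW_LOG_CLEANUP_ENABLED)
```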
## Dependencies

- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.

## Storage & Files

- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.

## Redis & Shared State

- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`; see the sketch below.
- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
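A minimal locking sketch (the key name and timeouts are illustrative; `redis_client.lock` follows the standard redis-py `Lock` API):

```python
from extensions.ext_redis import redis_client

def refresh_provider_cache(tenant_id: str) -> None:
    # timeout auto-releases the lock if the holder dies;
    # blocking_timeout bounds how long we wait to acquire it.
    lock = redis_client.lock(f"provider_cache_refresh:{tenant_id}", timeout=60, blocking_timeout=5)
    with lock:
        ...  # rebuild the cached provider metadata for this tenant
```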
## Models

- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.

## Vector Stores

- Vector client implementations live in `core/rag/datasource/vdb/<provider>`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.

## Observability & OTEL

- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.

## Ops Integrations

- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.

## Controllers, Services, Core

- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.

## Plugins, Tools, Providers

- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.

## Async Workloads

See `agent_skills/trigger.md` for more detailed documentation.

- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.

## Database & Migrations

- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
- Generate migrations with `uv run --project api flask db revision --autogenerate -m "<summary>"`, then review the diff; never hand-edit the database outside Alembic.
- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.

## CLI Commands

- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask <command>`.
- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
- Before adding a new command, check whether an existing service can be reused, and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).

## When You Add Features

- Check for an existing helper or service before writing a new util.
- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.
api/agent_skills/plugin.md (new file, 1 line):

// TBD
api/agent_skills/plugin_oauth.md (new file, 1 line):

// TBD
api/agent_skills/trigger.md (new file, 53 lines):

## Overview

Trigger is the collective name for the nodes we call `Start` nodes. The `Start` concept corresponds to `RootNode` in the workflow engine (`core/workflow/graph_engine`): a `Start` node is the entry point of a workflow, and every workflow run begins at one.

## Trigger nodes

- `UserInput`
- `Trigger Webhook`
- `Trigger Schedule`
- `Trigger Plugin`

### UserInput

Before the `Trigger` concept was introduced, this was simply the `Start` node; it has been renamed to `UserInput` to avoid confusion. It is closely tied to the `ServiceAPI` in `controllers/service_api/app`:

1. The `UserInput` node declares a list of arguments that must be provided by the user; these are ultimately converted into variables in the workflow variable pool.
1. The `ServiceAPI` accepts those arguments and passes them through to the `UserInput` node.
1. For the detailed implementation, see `core/workflow/nodes/start`.

### Trigger Webhook

Inside the Webhook node, Dify provides a UI panel that lets the user define an HTTP manifest (`WebhookData` in `core/workflow/nodes/trigger_webhook/entities.py`). Dify also generates a random webhook id for each `Trigger Webhook` node; the implementation lives in `core/trigger/utils/endpoint.py`. `webhook-debug` is a debug mode for webhooks, found in `controllers/trigger/webhook.py`.

Requests to the `webhook` endpoint are converted into variables in the workflow variable pool during workflow execution.

### Trigger Schedule

The `Trigger Schedule` node lets the user define a schedule that triggers the workflow; the detailed manifest is in `core/workflow/nodes/trigger_schedule/entities.py`. A poller and an executor handle millions of schedules; see `docker/entrypoint.sh` and `schedule/workflow_schedule_task.py` for reference.

To achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and `events/event_handlers/sync_workflow_schedule_when_app_published.py` syncs workflow schedule plans when an app is published.

### Trigger Plugin

The `Trigger Plugin` node lets users define their own distributed trigger plugins. Whenever a request is received, Dify forwards it to the plugin and waits for the parsed variables.

1. Requests are saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`.
1. Plugins accept those requests and parse variables from them; see `core/plugin/impl/trigger.py` for details.

Dify also introduces a `subscription` concept: an endpoint address from Dify is bound to a third-party webhook service such as `Github`, `Slack`, `Linear`, `GoogleDrive`, or `Gmail`. Once a subscription is created, Dify continually receives requests from the platform and handles them one by one.

## Worker Pool / Async Task

Every event that triggers a new workflow run is handled asynchronously; the unified entrypoint is `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.

The underlying infrastructure is `celery`, already configured in `docker/entrypoint.sh`, and the consumers live in `tasks/async_workflow_tasks.py`. Three queues handle different tiers of users: `PROFESSIONAL_QUEUE`, `TEAM_QUEUE`, and `SANDBOX_QUEUE`. A routing sketch follows.
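The sketch below is illustrative only: the task name and dispatch helper are hypothetical, and only the queue names come from the code above.

```python
from celery import Celery

app = Celery("dify")  # broker configuration omitted in this sketch

TIER_QUEUES = {
    "professional": "PROFESSIONAL_QUEUE",
    "team": "TEAM_QUEUE",
    "sandbox": "SANDBOX_QUEUE",
}

@app.task(name="tasks.run_triggered_workflow")  # hypothetical task name
def run_triggered_workflow(workflow_id: str, tenant_id: str) -> None:
    ...  # execute the workflow run

def trigger_workflow_async(workflow_id: str, tenant_id: str, tier: str) -> None:
    # Route the run onto the Celery queue that matches the tenant's plan tier.
    run_triggered_workflow.apply_async(
        args=[workflow_id, tenant_id],
        queue=TIER_QUEUES.get(tier, "SANDBOX_QUEUE"),
    )
```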
## Debug Strategy

Dify divides users into two groups: builders and end users.

Builders are the users who create workflows, and for them debugging is a critical part of workflow development. As the start nodes of workflows, trigger nodes can `listen` to events from `WebhookDebug`, `Schedule`, and `Plugin`; the debugging process is created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.

A polling process can be seen as a combination of single `poll` operations: each `poll` fetches events cached in `Redis` and returns `None` if no event is found. In more detail, `core/trigger/debug/event_bus.py` handles the polling process, and `core/trigger/debug/event_selectors.py` selects the event poller based on the trigger type.
api/app.py (21 changes)

```diff
@@ -1,7 +1,7 @@
 import sys
 
 
-def is_db_command():
+def is_db_command() -> bool:
     if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
         return True
     return False
@@ -13,23 +13,12 @@ if is_db_command():
 
     app = create_migrations_app()
 else:
-    # It seems that JetBrains Python debugger does not work well with gevent,
-    # so we need to disable gevent in debug mode.
-    # If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
-    # if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
-    #     from gevent import monkey
-    #
-    #     # gevent
-    #     monkey.patch_all()
-    #
-    #     from grpc.experimental import gevent as grpc_gevent  # type: ignore
-    #
-    #     # grpc gevent
-    #     grpc_gevent.init_gevent()
-
-    #     import psycogreen.gevent  # type: ignore
-    #
-    #     psycogreen.gevent.patch_psycopg()
+    # Gunicorn and Celery handle monkey patching automatically in production by
+    # specifying the `gevent` worker class. Manual monkey patching is not required here.
+    #
+    # See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
+    #
+    # For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
 
     from app_factory import create_app
```
@@ -15,12 +15,12 @@ from sqlalchemy.orm import sessionmaker
from configs import dify_config
from constants.languages import languages
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db

@@ -321,6 +321,8 @@ def migrate_knowledge_vector_database():
            )

            datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
            if not datasets.items:
                break
        except SQLAlchemyError:
            raise

@@ -1227,6 +1229,55 @@ def setup_system_tool_oauth_client(provider, client_params):
    click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))


@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_trigger_oauth_client(provider, client_params):
    """
    Setup system trigger oauth client
    """
    from models.provider_ids import TriggerProviderID
    from models.trigger import TriggerOAuthSystemClient

    provider_id = TriggerProviderID(provider)
    provider_name = provider_id.provider_name
    plugin_id = provider_id.plugin_id

    try:
        # json validate
        click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
        client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
        click.echo(click.style("Client params validated successfully.", fg="green"))

        click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
        click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
        oauth_client_params = encrypt_system_oauth_params(client_params_dict)
        click.echo(click.style("Client params encrypted successfully.", fg="green"))
    except Exception as e:
        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
        return

    deleted_count = (
        db.session.query(TriggerOAuthSystemClient)
        .filter_by(
            provider=provider_name,
            plugin_id=plugin_id,
        )
        .delete()
    )
    if deleted_count > 0:
        click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))

    oauth_client = TriggerOAuthSystemClient(
        provider=provider_name,
        plugin_id=plugin_id,
        encrypted_oauth_params=oauth_client_params,
    )
    db.session.add(oauth_client)
    db.session.commit()
    click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))


def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
    """
    Find draft variables that reference non-existent apps.

@@ -1420,7 +1471,10 @@ def setup_datasource_oauth_client(provider, client_params):


@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials():
@click.option(
    "--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
)
def transform_datasource_credentials(environment: str):
    """
    Transform datasource credentials
    """

@@ -1431,9 +1485,14 @@ def transform_datasource_credentials():
    notion_plugin_id = "langgenius/notion_datasource"
    firecrawl_plugin_id = "langgenius/firecrawl_datasource"
    jina_plugin_id = "langgenius/jina_datasource"
    notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)  # pyright: ignore[reportPrivateUsage]
    firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)  # pyright: ignore[reportPrivateUsage]
    jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)  # pyright: ignore[reportPrivateUsage]
    if environment == "online":
        notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)  # pyright: ignore[reportPrivateUsage]
        firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)  # pyright: ignore[reportPrivateUsage]
        jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)  # pyright: ignore[reportPrivateUsage]
    else:
        notion_plugin_unique_identifier = None
        firecrawl_plugin_unique_identifier = None
        jina_plugin_unique_identifier = None
    oauth_credential_type = CredentialType.OAUTH2
    api_key_credential_type = CredentialType.API_KEY

@@ -1599,7 +1658,7 @@ def transform_datasource_credentials():
                        "integration_secret": api_key,
                    }
                    datasource_provider = DatasourceProvider(
                        provider="jina",
                        provider="jinareader",
                        tenant_id=tenant_id,
                        plugin_id=jina_plugin_id,
                        auth_type=api_key_credential_type.value,

@@ -174,6 +174,33 @@ class CodeExecutionSandboxConfig(BaseSettings):
    )


class TriggerConfig(BaseSettings):
    """
    Configuration for trigger
    """

    WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
        description="Maximum allowed size for webhook request bodies in bytes",
        default=10485760,
    )


class AsyncWorkflowConfig(BaseSettings):
    """
    Configuration for async workflow
    """

    ASYNC_WORKFLOW_SCHEDULER_GRANULARITY: int = Field(
        description="Granularity for async workflow scheduler, "
        "sometime, few users could block the queue due to some time-consuming tasks, "
        "to avoid this, workflow can be suspended if needed, to achieve"
        "this, a time-based checker is required, every granularity seconds, "
        "the checker will check the workflow queue and suspend the workflow",
        default=120,
        ge=1,
    )


class PluginConfig(BaseSettings):
    """
    Plugin configs

@@ -263,6 +290,8 @@ class EndpointConfig(BaseSettings):
        description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
    )

    TRIGGER_URL: str = Field(description="Template url for triggers", default="http://localhost:5001")


class FileAccessConfig(BaseSettings):
    """

@@ -331,12 +360,42 @@ class FileUploadConfig(BaseSettings):
        default=10,
    )

    inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
        description=(
            "Comma-separated list of file extensions that are blocked from upload. "
            "Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
            "Empty by default to allow all file types."
        ),
        validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
        default="",
    )

    @computed_field  # type: ignore[misc]
    @property
    def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
        """
        Parse and return the blacklist as a set of lowercase extensions.
        Returns an empty set if no blacklist is configured.
        """
        if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
            return set()
        return {
            ext.strip().lower().strip(".")
            for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
            if ext.strip()
        }

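The computed field above is the whole parsing story: trim each entry, lowercase it, strip surrounding dots, and drop empties. Extracted standalone for illustration:

```python
def parse_blacklist(raw: str) -> set[str]:
    # Mirrors UPLOAD_FILE_EXTENSION_BLACKLIST's normalization.
    return {ext.strip().lower().strip(".") for ext in raw.split(",") if ext.strip()}


assert parse_blacklist("EXE, .Bat, ,sh,") == {"exe", "bat", "sh"}
assert parse_blacklist("") == set()
```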
class HttpConfig(BaseSettings):
    """
    HTTP-related configurations for the application
    """

    COOKIE_DOMAIN: str = Field(
        description="Explicit cookie domain for console/service cookies when sharing across subdomains",
        default="",
    )

    API_COMPRESSION_ENABLED: bool = Field(
        description="Enable or disable gzip compression for HTTP responses",
        default=False,

@@ -915,6 +974,11 @@ class DataSetConfig(BaseSettings):
        default=True,
    )

    DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
        description="Maximum number of segments for dataset segments API (0 for unlimited)",
        default=0,
    )


class WorkspaceConfig(BaseSettings):
    """

@@ -990,6 +1054,44 @@ class CeleryScheduleTasksConfig(BaseSettings):
        description="Enable check upgradable plugin task",
        default=True,
    )
    ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
        description="Enable workflow schedule poller task",
        default=True,
    )
    WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
        description="Workflow schedule poller interval in minutes",
        default=1,
    )
    WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
        description="Maximum number of schedules to process in each poll batch",
        default=100,
    )
    WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
        description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
        default=0,
    )

    # Trigger provider refresh (simple version)
    ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
        description="Enable trigger provider refresh poller",
        default=True,
    )
    TRIGGER_PROVIDER_REFRESH_INTERVAL: int = Field(
        description="Trigger provider refresh poller interval in minutes",
        default=1,
    )
    TRIGGER_PROVIDER_REFRESH_BATCH_SIZE: int = Field(
        description="Max trigger subscriptions to process per tick",
        default=200,
    )
    TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
        description="Proactive credential refresh threshold in seconds",
        default=180,
    )
    TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
        description="Proactive subscription refresh threshold in seconds",
        default=60 * 60,
    )


class PositionConfig(BaseSettings):

@@ -1088,7 +1190,7 @@ class AccountConfig(BaseSettings):


class WorkflowLogConfig(BaseSettings):
    WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
    WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
    WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
    WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
        default=100, description="Batch size for workflow run log cleanup operations"

@@ -1107,12 +1209,21 @@ class SwaggerUIConfig(BaseSettings):
    )


class TenantIsolatedTaskQueueConfig(BaseSettings):
    TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
        description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
        default=1,
    )


class FeatureConfig(
    # place the configs in alphabet order
    AppExecutionConfig,
    AuthConfig,  # Changed from OAuthConfig to AuthConfig
    BillingConfig,
    CodeExecutionSandboxConfig,
    TriggerConfig,
    AsyncWorkflowConfig,
    PluginConfig,
    MarketplaceConfig,
    DataSetConfig,

@@ -1131,6 +1242,7 @@ class FeatureConfig(
    RagEtlConfig,
    RepositoryConfig,
    SecurityConfig,
    TenantIsolatedTaskQueueConfig,
    ToolConfig,
    UpdateConfig,
    WorkflowConfig,

@@ -22,6 +22,11 @@ class WeaviateConfig(BaseSettings):
        default=True,
    )

    WEAVIATE_GRPC_ENDPOINT: str | None = Field(
        description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
        default=None,
    )

    WEAVIATE_BATCH_SIZE: PositiveInt = Field(
        description="Number of objects to be processed in a single batch operation (default is 100)",
        default=100,

@@ -9,6 +9,7 @@ if TYPE_CHECKING:
    from core.model_runtime.entities.model_entities import AIModelEntity
    from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
    from core.tools.plugin_tool.provider import PluginToolProviderController
    from core.trigger.provider import PluginTriggerProviderController


"""

@@ -41,3 +42,11 @@ datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginPro
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
    ContextVar("datasource_plugin_providers_lock")
)

plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
    ContextVar("plugin_trigger_providers")
)

plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
    ContextVar("plugin_trigger_providers_lock")
)

@@ -25,6 +25,12 @@ class UnsupportedFileTypeError(BaseHTTPException):
    code = 415


class BlockedFileExtensionError(BaseHTTPException):
    error_code = "file_extension_blocked"
    description = "The file extension is blocked for security reasons."
    code = 400


class TooManyFilesError(BaseHTTPException):
    error_code = "too_many_files"
    description = "Only one file is allowed."

@@ -66,6 +66,7 @@ from .app import (
    workflow_draft_variable,
    workflow_run,
    workflow_statistic,
    workflow_trigger,
)

# Import auth controllers

@@ -126,6 +127,7 @@ from .workspace import (
    models,
    plugin,
    tool_providers,
    trigger_providers,
    workspace,
)

@@ -196,6 +198,7 @@ __all__ = [
    "statistic",
    "tags",
    "tool_providers",
    "trigger_providers",
    "version",
    "website",
    "workflow",

@@ -203,5 +206,6 @@ __all__ = [
    "workflow_draft_variable",
    "workflow_run",
    "workflow_statistic",
    "workflow_trigger",
    "workspace",
]

@@ -16,6 +16,7 @@ from fields.annotation_fields import (
    annotation_fields,
    annotation_hit_history_fields,
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import AppAnnotationService

@@ -175,8 +176,10 @@ class AnnotationApi(Resource):
        api.model(
            "CreateAnnotationRequest",
            {
                "question": fields.String(required=True, description="Question text"),
                "answer": fields.String(required=True, description="Answer text"),
                "message_id": fields.String(description="Message ID (optional)"),
                "question": fields.String(description="Question text (required when message_id not provided)"),
                "answer": fields.String(description="Answer text (use 'answer' or 'content')"),
                "content": fields.String(description="Content text (use 'answer' or 'content')"),
                "annotation_reply": fields.Raw(description="Annotation reply data"),
            },
        )

@@ -193,11 +196,14 @@ class AnnotationApi(Resource):
        app_id = str(app_id)
        parser = (
            reqparse.RequestParser()
            .add_argument("question", required=True, type=str, location="json")
            .add_argument("answer", required=True, type=str, location="json")
            .add_argument("message_id", required=False, type=uuid_value, location="json")
            .add_argument("question", required=False, type=str, location="json")
            .add_argument("answer", required=False, type=str, location="json")
            .add_argument("content", required=False, type=str, location="json")
            .add_argument("annotation_reply", required=False, type=dict, location="json")
        )
        args = parser.parse_args()
        annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
        annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
        return annotation

    @setup_required

@@ -1,7 +1,5 @@
from datetime import datetime

import pytz
import sqlalchemy as sa
from flask import abort
from flask_restx import Resource, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import func, or_

@@ -19,7 +17,7 @@ from fields.conversation_fields import (
    conversation_pagination_fields,
    conversation_with_summary_pagination_fields,
)
from libs.datetime_utils import naive_utc_now
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation

@@ -90,25 +88,17 @@ class CompletionConversationApi(Resource):

        account = current_user
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
        try:
            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

        if start_datetime_utc:
            query = query.where(Conversation.created_at >= start_datetime_utc)

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=59)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

        if end_datetime_utc:
            end_datetime_utc = end_datetime_utc.replace(second=59)
            query = query.where(Conversation.created_at < end_datetime_utc)

        # FIXME, the type ignore in this file
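The repeated pytz blocks above collapse into a single `parse_time_range` helper. Its definition is not part of this diff, so the following is only a sketch consistent with the call sites; the exact normalization and validation rules are assumptions.

```python
from datetime import datetime

import pytz


def parse_time_range(start: str | None, end: str | None, tz: str) -> tuple[datetime | None, datetime | None]:
    """Convert 'YYYY-MM-DD HH:MM' strings from the user's timezone to UTC.

    Sketch only: mirrors how the call sites use the helper, not its real body.
    """
    timezone = pytz.timezone(tz)

    def to_utc(value: str | None) -> datetime | None:
        if not value:
            return None
        local = timezone.localize(datetime.strptime(value, "%Y-%m-%d %H:%M").replace(second=0))
        return local.astimezone(pytz.utc)

    start_utc, end_utc = to_utc(start), to_utc(end)
    if start_utc and end_utc and start_utc > end_utc:
        raise ValueError("start must not be later than end")  # assumed validation
    return start_utc, end_utc
```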
@@ -270,29 +260,21 @@ class ChatConversationApi(Resource):

        account = current_user
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
        try:
            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

        if start_datetime_utc:
            match args["sort_by"]:
                case "updated_at" | "-updated_at":
                    query = query.where(Conversation.updated_at >= start_datetime_utc)
                case "created_at" | "-created_at" | _:
                    query = query.where(Conversation.created_at >= start_datetime_utc)

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=59)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

        if end_datetime_utc:
            end_datetime_utc = end_datetime_utc.replace(second=59)
            match args["sort_by"]:
                case "updated_at" | "-updated_at":
                    query = query.where(Conversation.updated_at <= end_datetime_utc)

@@ -11,6 +11,7 @@ from controllers.console.app.error import (
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator

@@ -206,13 +207,11 @@ class InstructionGenerateApi(Resource):
        )
        args = parser.parse_args()
        _, current_tenant_id = current_account_with_tenant()
        code_template = (
            Python3CodeProvider.get_default_code()
            if args["language"] == "python"
            else (JavascriptCodeProvider.get_default_code())
            if args["language"] == "javascript"
            else ""
        providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
        code_provider: type[CodeNodeProvider] | None = next(
            (p for p in providers if p.is_accept_language(args["language"])), None
        )
        code_template = code_provider.get_default_code() if code_provider else ""
        try:
            # Generate from nothing for a workflow node
            if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":

@@ -16,7 +16,6 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_resource_check,
    edit_permission_required,
    setup_required,
)

@@ -24,12 +23,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.conversation_fields import annotation_fields, message_detail_fields
from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService

@@ -194,45 +192,6 @@ class MessageFeedbackApi(Resource):
        return {"result": "success"}


@console_ns.route("/apps/<uuid:app_id>/annotations")
class MessageAnnotationApi(Resource):
    @api.doc("create_message_annotation")
    @api.doc(description="Create message annotation")
    @api.doc(params={"app_id": "Application ID"})
    @api.expect(
        api.model(
            "MessageAnnotationRequest",
            {
                "message_id": fields.String(description="Message ID"),
                "question": fields.String(required=True, description="Question text"),
                "answer": fields.String(required=True, description="Answer text"),
                "annotation_reply": fields.Raw(description="Annotation reply"),
            },
        )
    )
    @api.response(200, "Annotation created successfully", annotation_fields)
    @api.response(403, "Insufficient permissions")
    @marshal_with(annotation_fields)
    @get_app_model
    @setup_required
    @login_required
    @cloud_edition_billing_resource_check("annotation")
    @account_initialization_required
    @edit_permission_required
    def post(self, app_model):
        parser = (
            reqparse.RequestParser()
            .add_argument("message_id", required=False, type=uuid_value, location="json")
            .add_argument("question", required=True, type=str, location="json")
            .add_argument("answer", required=True, type=str, location="json")
            .add_argument("annotation_reply", required=False, type=dict, location="json")
        )
        args = parser.parse_args()
        annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)

        return annotation


@console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource):
    @api.doc("get_annotation_count")

@@ -1,9 +1,7 @@
from datetime import datetime
from decimal import Decimal

import pytz
import sqlalchemy as sa
from flask import jsonify
from flask import abort, jsonify
from flask_restx import Resource, fields, reqparse

from controllers.console import api, console_ns

@@ -11,6 +9,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models import AppMode, Message

@@ -56,26 +55,16 @@ WHERE
        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
        assert account.timezone is not None

        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
        try:
            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

        if start_datetime_utc:
            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

        if end_datetime_utc:
            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

@@ -120,8 +109,11 @@ class DailyConversationStatistic(Resource):
        )
        args = parser.parse_args()
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        try:
            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

        stmt = (
            sa.select(

@@ -134,18 +126,10 @@ class DailyConversationStatistic(Resource):
            .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
        )

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)
            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
        if start_datetime_utc:
            stmt = stmt.where(Message.created_at >= start_datetime_utc)

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)
            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
        if end_datetime_utc:
            stmt = stmt.where(Message.created_at < end_datetime_utc)

        stmt = stmt.group_by("date").order_by("date")

@@ -198,26 +182,17 @@ WHERE
    AND invoke_from != :invoke_from"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
        try:
            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

        if start_datetime_utc:
            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

        if end_datetime_utc:
            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

@@ -273,26 +248,17 @@ WHERE
    AND invoke_from != :invoke_from"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
        try:
            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

        if start_datetime_utc:
            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

        if end_datetime_utc:
            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

@@ -357,26 +323,17 @@ FROM
    AND m.invoke_from != :invoke_from"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
        try:
            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

        if start_datetime_utc:
            sql_query += " AND c.created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

        if end_datetime_utc:
            sql_query += " AND c.created_at < :end"
            arg_dict["end"] = end_datetime_utc

@@ -446,26 +403,17 @@ WHERE
    AND m.invoke_from != :invoke_from"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
        try:
            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

        if start_datetime_utc:
            sql_query += " AND m.created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

        if end_datetime_utc:
            sql_query += " AND m.created_at < :end"
            arg_dict["end"] = end_datetime_utc

@@ -525,26 +473,17 @@ WHERE
    AND invoke_from != :invoke_from"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
        try:
            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

        if start_datetime_utc:
            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

        if end_datetime_utc:
            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

@@ -602,26 +541,17 @@ WHERE
    AND invoke_from != :invoke_from"""
        arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
        try:
            start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

        if start_datetime_utc:
            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

        if end_datetime_utc:
            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

@@ -16,9 +16,19 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.debug.event_selectors import (
    TriggerDebugEvent,
    TriggerDebugEventPoller,
    create_event_poller,
    select_trigger_debug_events,
)
from core.workflow.enums import NodeType
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import file_factory, variable_factory

@@ -37,6 +47,7 @@ from services.errors.llm import InvokeRateLimitError
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService

logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000


# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing

@@ -102,7 +113,18 @@ class DraftWorkflowApi(Resource):
            },
        )
    )
    @api.response(200, "Draft workflow synced successfully", workflow_fields)
    @api.response(
        200,
        "Draft workflow synced successfully",
        api.model(
            "SyncDraftWorkflowResponse",
            {
                "result": fields.String,
                "hash": fields.String,
                "updated_at": fields.String,
            },
        ),
    )
    @api.response(400, "Invalid workflow configuration")
    @api.response(403, "Permission denied")
    @edit_permission_required

@@ -915,3 +937,234 @@ class DraftWorkflowNodeLastRunApi(Resource):
        if node_exec is None:
            raise NotFound("last run not found")
        return node_exec


@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run")
class DraftWorkflowTriggerRunApi(Resource):
    """
    Full workflow debug - Polling API for trigger events
    Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
    """

    @api.doc("poll_draft_workflow_trigger_run")
    @api.doc(description="Poll for trigger events and execute full workflow when event arrives")
    @api.doc(params={"app_id": "Application ID"})
    @api.expect(
        api.model(
            "DraftWorkflowTriggerRunRequest",
            {
                "node_id": fields.String(required=True, description="Node ID"),
            },
        )
    )
    @api.response(200, "Trigger event received and workflow executed successfully")
    @api.response(403, "Permission denied")
    @api.response(500, "Internal server error")
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.WORKFLOW])
    @edit_permission_required
    def post(self, app_model: App):
        """
        Poll for trigger events and execute full workflow when event arrives
        """
        current_user, _ = current_account_with_tenant()
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
        args = parser.parse_args()
        node_id = args["node_id"]
        workflow_service = WorkflowService()
        draft_workflow = workflow_service.get_draft_workflow(app_model)
        if not draft_workflow:
            raise ValueError("Workflow not found")

        poller: TriggerDebugEventPoller = create_event_poller(
            draft_workflow=draft_workflow,
            tenant_id=app_model.tenant_id,
            user_id=current_user.id,
            app_id=app_model.id,
            node_id=node_id,
        )
        event: TriggerDebugEvent | None = None
        try:
            event = poller.poll()
            if not event:
                return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})
            workflow_args = dict(event.workflow_args)
            workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
            return helper.compact_generate_response(
                AppGenerateService.generate(
                    app_model=app_model,
                    user=current_user,
                    args=workflow_args,
                    invoke_from=InvokeFrom.DEBUGGER,
                    streaming=True,
                    root_node_id=node_id,
                )
            )
        except InvokeRateLimitError as ex:
            raise InvokeRateLimitHttpError(ex.description)
        except PluginInvokeError as e:
            return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
        except Exception as e:
            logger.exception("Error polling trigger debug event")
            raise e


@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run")
class DraftWorkflowTriggerNodeApi(Resource):
    """
    Single node debug - Polling API for trigger events
    Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run
    """

    @api.doc("poll_draft_workflow_trigger_node")
    @api.doc(description="Poll for trigger events and execute single node when event arrives")
    @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
    @api.response(200, "Trigger event received and node executed successfully")
    @api.response(403, "Permission denied")
    @api.response(500, "Internal server error")
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.WORKFLOW])
    @edit_permission_required
    def post(self, app_model: App, node_id: str):
        """
        Poll for trigger events and execute single node when event arrives
        """
        current_user, _ = current_account_with_tenant()

        workflow_service = WorkflowService()
        draft_workflow = workflow_service.get_draft_workflow(app_model)
        if not draft_workflow:
            raise ValueError("Workflow not found")

        node_config = draft_workflow.get_node_config_by_id(node_id=node_id)
        if not node_config:
            raise ValueError("Node data not found for node %s", node_id)
        node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
        event: TriggerDebugEvent | None = None
        # for schedule trigger, when run single node, just execute directly
        if node_type == NodeType.TRIGGER_SCHEDULE:
            event = TriggerDebugEvent(
                workflow_args={},
                node_id=node_id,
            )
        # for other trigger types, poll for the event
        else:
            try:
                poller: TriggerDebugEventPoller = create_event_poller(
                    draft_workflow=draft_workflow,
                    tenant_id=app_model.tenant_id,
                    user_id=current_user.id,
                    app_id=app_model.id,
                    node_id=node_id,
                )
                event = poller.poll()
            except PluginInvokeError as e:
                return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
            except Exception as e:
                logger.exception("Error polling trigger debug event")
                raise e
        if not event:
            return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})

        raw_files = event.workflow_args.get("files")
        files = _parse_file(draft_workflow, raw_files if isinstance(raw_files, list) else None)
        try:
            node_execution = workflow_service.run_draft_workflow_node(
                app_model=app_model,
                draft_workflow=draft_workflow,
                node_id=node_id,
                user_inputs=event.workflow_args.get("inputs") or {},
                account=current_user,
                query="",
                files=files,
            )
            return jsonable_encoder(node_execution)
        except Exception as e:
            logger.exception("Error running draft workflow trigger node")
            return jsonable_encoder(
                {"status": "error", "error": "An unexpected error occurred while running the node."}
            ), 400


@console_ns.route("/apps/<uuid:app_id>/workflows/draft/trigger/run-all")
class DraftWorkflowTriggerRunAllApi(Resource):
    """
    Full workflow debug - Polling API for trigger events
    Path: /apps/<uuid:app_id>/workflows/draft/trigger/run-all
    """

    @api.doc("draft_workflow_trigger_run_all")
    @api.doc(description="Full workflow debug when the start node is a trigger")
    @api.doc(params={"app_id": "Application ID"})
    @api.expect(
        api.model(
            "DraftWorkflowTriggerRunAllRequest",
            {
                "node_ids": fields.List(fields.String, required=True, description="Node IDs"),
            },
        )
    )
    @api.response(200, "Workflow executed successfully")
    @api.response(403, "Permission denied")
    @api.response(500, "Internal server error")
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.WORKFLOW])
    @edit_permission_required
    def post(self, app_model: App):
        """
        Full workflow debug when the start node is a trigger
        """
        current_user, _ = current_account_with_tenant()

        parser = reqparse.RequestParser()
        parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False)
        args = parser.parse_args()
        node_ids = args["node_ids"]
        workflow_service = WorkflowService()
        draft_workflow = workflow_service.get_draft_workflow(app_model)
        if not draft_workflow:
            raise ValueError("Workflow not found")

        try:
            trigger_debug_event: TriggerDebugEvent | None = select_trigger_debug_events(
                draft_workflow=draft_workflow,
                app_model=app_model,
                user_id=current_user.id,
                node_ids=node_ids,
            )
        except PluginInvokeError as e:
            return jsonable_encoder({"status": "error", "error": e.to_user_friendly_error()}), 400
        except Exception as e:
            logger.exception("Error polling trigger debug event")
            raise e
        if trigger_debug_event is None:
            return jsonable_encoder({"status": "waiting", "retry_in": LISTENING_RETRY_IN})

        try:
            workflow_args = dict(trigger_debug_event.workflow_args)
            workflow_args[SKIP_PREPARE_USER_INPUTS_KEY] = True
            response = AppGenerateService.generate(
                app_model=app_model,
                user=current_user,
                args=workflow_args,
                invoke_from=InvokeFrom.DEBUGGER,
                streaming=True,
                root_node_id=trigger_debug_event.node_id,
            )
            return helper.compact_generate_response(response)
        except InvokeRateLimitError as ex:
            raise InvokeRateLimitHttpError(ex.description)
        except Exception:
            logger.exception("Error running draft workflow trigger run-all")
            return jsonable_encoder(
                {
                    "status": "error",
                }
            ), 400

@@ -28,6 +28,7 @@ class WorkflowAppLogApi(Resource):
            "created_at__after": "Filter logs created after this timestamp",
            "created_by_end_user_session_id": "Filter by end user session ID",
            "created_by_account": "Filter by account",
            "detail": "Whether to return detailed logs",
            "page": "Page number (1-99999)",
            "limit": "Number of items per page (1-100)",
        }

@@ -68,6 +69,7 @@ class WorkflowAppLogApi(Resource):
                required=False,
                default=None,
            )
            .add_argument("detail", type=bool, location="args", required=False, default=False)
            .add_argument("page", type=int_range(1, 99999), default=1, location="args")
            .add_argument("limit", type=int_range(1, 100), default=20, location="args")
        )

@@ -92,6 +94,7 @@ class WorkflowAppLogApi(Resource):
            created_at_after=args.created_at__after,
            page=args.page,
            limit=args.limit,
            detail=args.detail,
            created_by_end_user_session_id=args.created_by_end_user_session_id,
            created_by_account=args.created_by_account,
        )

@@ -1,23 +1,26 @@
from datetime import datetime
from decimal import Decimal

import pytz
import sqlalchemy as sa
from flask import jsonify
from flask import abort, jsonify
from flask_restx import Resource, reqparse
from sqlalchemy.orm import sessionmaker

from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory


@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

    @api.doc("get_workflow_daily_runs_statistic")
    @api.doc(description="Get workflow daily runs statistics")
    @api.doc(params={"app_id": "Application ID"})

@@ -37,57 +40,32 @@ class WorkflowDailyRunsStatistic(Resource):
        )
        args = parser.parse_args()

        sql_query = """SELECT
    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
    COUNT(id) AS runs
FROM
    workflow_runs
WHERE
    app_id = :app_id
    AND triggered_from = :triggered_from"""
        arg_dict = {
            "tz": account.timezone,
            "app_id": app_model.id,
            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
        }
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)
        try:
            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

        sql_query += " GROUP BY date ORDER BY date"

        response_data = []

        with db.engine.begin() as conn:
            rs = conn.execute(sa.text(sql_query), arg_dict)
            for i in rs:
                response_data.append({"date": str(i.date), "runs": i.runs})
        response_data = self._workflow_run_repo.get_daily_runs_statistics(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            start_date=start_date,
            end_date=end_date,
            timezone=account.timezone,
        )

        return jsonify({"data": response_data})

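The raw-SQL blocks move behind the workflow-run repository created by `DifyAPIRepositoryFactory`. The repository method itself is not shown in this diff; the sketch below rebuilds `get_daily_runs_statistics` from the old inline query, so its body and return shape are assumptions.

```python
from datetime import datetime
from typing import Any

import sqlalchemy as sa
from sqlalchemy.orm import sessionmaker


class SketchAPIWorkflowRunRepository:
    """Illustrative stand-in for the repository behind DifyAPIRepositoryFactory."""

    def __init__(self, session_maker: sessionmaker):
        self._session_maker = session_maker

    def get_daily_runs_statistics(
        self,
        *,
        tenant_id: str,  # accepted for parity with the new call sites
        app_id: str,
        triggered_from: Any,
        start_date: datetime | None,
        end_date: datetime | None,
        timezone: str,
    ) -> list[dict[str, Any]]:
        # Rebuilt from the old inline query; groups runs per local-time day.
        sql = (
            "SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz)) AS date, "
            "COUNT(id) AS runs FROM workflow_runs "
            "WHERE app_id = :app_id AND triggered_from = :triggered_from"
        )
        params: dict[str, Any] = {"tz": timezone, "app_id": app_id, "triggered_from": triggered_from}
        if start_date:
            sql += " AND created_at >= :start"
            params["start"] = start_date
        if end_date:
            sql += " AND created_at < :end"
            params["end"] = end_date
        sql += " GROUP BY date ORDER BY date"
        with self._session_maker() as session:
            rows = session.execute(sa.text(sql), params)
            return [{"date": str(row.date), "runs": row.runs} for row in rows]
```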
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
||||
class WorkflowDailyTerminalsStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@api.doc("get_workflow_daily_terminals_statistic")
|
||||
@api.doc(description="Get workflow daily terminals statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@ -107,57 +85,32 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
|
||||
}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
class WorkflowDailyTokenCostStatistic(Resource):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

    @api.doc("get_workflow_daily_token_cost_statistic")
    @api.doc(description="Get workflow daily token cost statistics")
    @api.doc(params={"app_id": "Application ID"})
@@ -177,62 +130,32 @@ class WorkflowDailyTokenCostStatistic(Resource):
        )
        args = parser.parse_args()

        sql_query = """SELECT
    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
    SUM(workflow_runs.total_tokens) AS token_count
FROM
    workflow_runs
WHERE
    app_id = :app_id
    AND triggered_from = :triggered_from"""
        arg_dict = {
            "tz": account.timezone,
            "app_id": app_model.id,
            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
        }
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)
        try:
            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

            sql_query += " AND created_at >= :start"
            arg_dict["start"] = start_datetime_utc

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

            sql_query += " AND created_at < :end"
            arg_dict["end"] = end_datetime_utc

        sql_query += " GROUP BY date ORDER BY date"

        response_data = []

        with db.engine.begin() as conn:
            rs = conn.execute(sa.text(sql_query), arg_dict)
            for i in rs:
                response_data.append(
                    {
                        "date": str(i.date),
                        "token_count": i.token_count,
                    }
                )
        response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            start_date=start_date,
            end_date=end_date,
            timezone=account.timezone,
        )

        return jsonify({"data": response_data})


@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
class WorkflowAverageAppInteractionStatistic(Resource):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
        self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

    @api.doc("get_workflow_average_app_interaction_statistic")
    @api.doc(description="Get workflow average app interaction statistics")
    @api.doc(params={"app_id": "Application ID"})
@@ -252,67 +175,20 @@ class WorkflowAverageAppInteractionStatistic(Resource):
        )
        args = parser.parse_args()

        sql_query = """SELECT
    AVG(sub.interactions) AS interactions,
    sub.date
FROM
    (
        SELECT
            DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
            c.created_by,
            COUNT(c.id) AS interactions
        FROM
            workflow_runs c
        WHERE
            c.app_id = :app_id
            AND c.triggered_from = :triggered_from
            {{start}}
            {{end}}
        GROUP BY
            date, c.created_by
    ) sub
GROUP BY
    sub.date"""
        arg_dict = {
            "tz": account.timezone,
            "app_id": app_model.id,
            "triggered_from": WorkflowRunTriggeredFrom.APP_RUN,
        }
        assert account.timezone is not None
        timezone = pytz.timezone(account.timezone)
        utc_timezone = pytz.utc

        if args["start"]:
            start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
            start_datetime = start_datetime.replace(second=0)
        try:
            start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
        except ValueError as e:
            abort(400, description=str(e))

            start_datetime_timezone = timezone.localize(start_datetime)
            start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)

            sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
            arg_dict["start"] = start_datetime_utc
        else:
            sql_query = sql_query.replace("{{start}}", "")

        if args["end"]:
            end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
            end_datetime = end_datetime.replace(second=0)

            end_datetime_timezone = timezone.localize(end_datetime)
            end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)

            sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
            arg_dict["end"] = end_datetime_utc
        else:
            sql_query = sql_query.replace("{{end}}", "")

        response_data = []

        with db.engine.begin() as conn:
            rs = conn.execute(sa.text(sql_query), arg_dict)
            for i in rs:
                response_data.append(
                    {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
                )
        response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            start_date=start_date,
            end_date=end_date,
            timezone=account.timezone,
        )

        return jsonify({"data": response_data})

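All three statistics endpoints above now funnel their start/end handling through a shared parse_time_range helper whose definition sits outside this diff. A minimal sketch consistent with the removed inline logic (the "%Y-%m-%d %H:%M" format, minute precision, account-timezone localization, and UTC conversion); the final range check is an assumption:

    from datetime import datetime

    import pytz


    def parse_time_range(start: str | None, end: str | None, tz_name: str):
        """Parse naive "%Y-%m-%d %H:%M" strings in the account timezone into UTC datetimes."""
        tz = pytz.timezone(tz_name)

        def to_utc(value: str) -> datetime:
            naive = datetime.strptime(value, "%Y-%m-%d %H:%M").replace(second=0)
            return tz.localize(naive).astimezone(pytz.utc)

        start_date = to_utc(start) if start else None
        end_date = to_utc(end) if end else None
        if start_date and end_date and start_date > end_date:
            raise ValueError("start must be earlier than end")
        return start_date, end_date
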
145
api/controllers/console/app/workflow_trigger.py
Normal file
@@ -0,0 +1,145 @@
import logging

from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound

from configs import dify_config
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required
from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode
from models.trigger import AppTrigger, WorkflowWebhookTrigger

logger = logging.getLogger(__name__)


class WebhookTriggerApi(Resource):
    """Webhook Trigger API"""

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    @marshal_with(webhook_trigger_fields)
    def get(self, app_model: App):
        """Get webhook trigger for a node"""
        parser = reqparse.RequestParser()
        parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
        args = parser.parse_args()

        node_id = str(args["node_id"])

        with Session(db.engine) as session:
            # Get webhook trigger for this app and node
            webhook_trigger = (
                session.query(WorkflowWebhookTrigger)
                .where(
                    WorkflowWebhookTrigger.app_id == app_model.id,
                    WorkflowWebhookTrigger.node_id == node_id,
                )
                .first()
            )

            if not webhook_trigger:
                raise NotFound("Webhook trigger not found for this node")

            return webhook_trigger


class AppTriggersApi(Resource):
    """App Triggers list API"""

    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    @marshal_with(triggers_list_fields)
    def get(self, app_model: App):
        """Get app triggers list"""
        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None

        with Session(db.engine) as session:
            # Get all triggers for this app using select API
            triggers = (
                session.execute(
                    select(AppTrigger)
                    .where(
                        AppTrigger.tenant_id == current_user.current_tenant_id,
                        AppTrigger.app_id == app_model.id,
                    )
                    .order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
                )
                .scalars()
                .all()
            )

            # Add computed icon field for each trigger
            url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
            for trigger in triggers:
                if trigger.trigger_type == "trigger-plugin":
                    trigger.icon = url_prefix + trigger.provider_name + "/icon"  # type: ignore
                else:
                    trigger.icon = ""  # type: ignore

            return {"data": triggers}


class AppTriggerEnableApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=AppMode.WORKFLOW)
    @marshal_with(trigger_fields)
    def post(self, app_model: App):
        """Update app trigger (enable/disable)"""
        parser = reqparse.RequestParser()
        parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
        parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
        args = parser.parse_args()

        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None
        if not current_user.has_edit_permission:
            raise Forbidden()

        trigger_id = args["trigger_id"]

        with Session(db.engine) as session:
            # Find the trigger using select
            trigger = session.execute(
                select(AppTrigger).where(
                    AppTrigger.id == trigger_id,
                    AppTrigger.tenant_id == current_user.current_tenant_id,
                    AppTrigger.app_id == app_model.id,
                )
            ).scalar_one_or_none()

            if not trigger:
                raise NotFound("Trigger not found")

            # Update status based on enable_trigger boolean
            trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED

            session.commit()
            session.refresh(trigger)

            # Add computed icon field
            url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
            if trigger.trigger_type == "trigger-plugin":
                trigger.icon = url_prefix + trigger.provider_name + "/icon"  # type: ignore
            else:
                trigger.icon = ""  # type: ignore

            return trigger


api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")

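For orientation, a sketch of driving the enable endpoint registered above from a client; the host, the /console/api prefix, and the bearer token are placeholders, while the route and JSON body mirror AppTriggerEnableApi:

    import requests

    BASE = "https://cloud.dify.ai/console/api"  # placeholder deployment URL

    resp = requests.post(
        f"{BASE}/apps/<app-uuid>/trigger-enable",
        json={"trigger_id": "<trigger-uuid>", "enable_trigger": True},
        headers={"Authorization": "Bearer <console-access-token>"},
    )
    resp.raise_for_status()
    print(resp.json())  # trigger_fields payload, including the computed icon URL
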
@@ -2,6 +2,7 @@ from flask_restx import Resource, reqparse

from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from enums.cloud_plan import CloudPlan
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService

@@ -16,7 +17,13 @@ class Subscription(Resource):
        current_user, current_tenant_id = current_account_with_tenant()
        parser = (
            reqparse.RequestParser()
            .add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
            .add_argument(
                "plan",
                type=str,
                required=True,
                location="args",
                choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
            )
            .add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
        )
        args = parser.parse_args()

@@ -746,7 +746,7 @@ class DocumentApi(DocumentResource):
            "name": document.name,
            "created_from": document.created_from,
            "created_by": document.created_by,
            "created_at": document.created_at.timestamp(),
            "created_at": int(document.created_at.timestamp()),
            "tokens": document.tokens,
            "indexing_status": document.indexing_status,
            "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -779,7 +779,7 @@ class DocumentApi(DocumentResource):
            "name": document.name,
            "created_from": document.created_from,
            "created_by": document.created_by,
            "created_at": document.created_at.timestamp(),
            "created_at": int(document.created_at.timestamp()),
            "tokens": document.tokens,
            "indexing_status": document.indexing_status,
            "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,

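The int() cast is the whole fix in both hunks: datetime.timestamp() returns a float, so created_at previously leaked fractional epoch seconds while completed_at was already integral. A quick illustration:

    from datetime import datetime, timezone

    dt = datetime(2024, 5, 1, 12, 0, 0, 123456, tzinfo=timezone.utc)
    print(dt.timestamp())       # 1714564800.123456, a float with microseconds
    print(int(dt.timestamp()))  # 1714564800, a consistent integer epoch second
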
@@ -4,7 +4,7 @@ from flask_restx import ( # type: ignore
)
from werkzeug.exceptions import Forbidden

from controllers.console import console_ns
from controllers.console import api, console_ns
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
@@ -12,9 +12,17 @@ from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService

parser = (
    reqparse.RequestParser()
    .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
    .add_argument("datasource_type", type=str, required=True, location="json")
    .add_argument("credential_id", type=str, required=False, location="json")
)


@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource):
    @api.expect(parser)
    @setup_required
    @login_required
    @account_initialization_required
@@ -26,12 +34,6 @@ class DataSourceContentPreviewApi(Resource):
        if not isinstance(current_user, Account):
            raise Forbidden()

        parser = (
            reqparse.RequestParser()
            .add_argument("inputs", type=dict, required=True, nullable=False, location="json")
            .add_argument("datasource_type", type=str, required=True, location="json")
            .add_argument("credential_id", type=str, required=False, location="json")
        )
        args = parser.parse_args()

        inputs = args.get("inputs")

@@ -66,13 +66,7 @@ class APIBasedExtensionAPI(Resource):
    @account_initialization_required
    @marshal_with(api_based_extension_fields)
    def post(self):
        parser = (
            reqparse.RequestParser()
            .add_argument("name", type=str, required=True, location="json")
            .add_argument("api_endpoint", type=str, required=True, location="json")
            .add_argument("api_key", type=str, required=True, location="json")
        )
        args = parser.parse_args()
        args = api.payload
        _, current_tenant_id = current_account_with_tenant()

        extension_data = APIBasedExtension(
@@ -125,13 +119,7 @@ class APIBasedExtensionDetailAPI(Resource):

        extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)

        parser = (
            reqparse.RequestParser()
            .add_argument("name", type=str, required=True, location="json")
            .add_argument("api_endpoint", type=str, required=True, location="json")
            .add_argument("api_key", type=str, required=True, location="json")
        )
        args = parser.parse_args()
        args = api.payload

        extension_data_from_db.name = args["name"]
        extension_data_from_db.api_endpoint = args["api_endpoint"]

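Swapping the per-field RequestParser for api.payload trades declarative validation for the raw request JSON: flask_restx's Api.payload is simply request.get_json(). A self-contained sketch of the new shape; the demo resource name and route are illustrative:

    from flask import Flask
    from flask_restx import Api, Resource

    app = Flask(__name__)
    api = Api(app)


    @api.route("/api-based-extension")
    class APIBasedExtensionDemo(Resource):
        def post(self):
            # No type coercion or "required" enforcement happens here;
            # a missing key surfaces as a KeyError inside the handler.
            args = api.payload
            return {"name": args["name"], "api_endpoint": args["api_endpoint"]}
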
@@ -8,6 +8,7 @@ import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import (
    BlockedFileExtensionError,
    FilenameNotExistsError,
    FileTooLargeError,
    NoFileUploadedError,
@@ -39,6 +40,7 @@ class FileApi(Resource):
        return {
            "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
            "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
            "file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
            "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
            "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
            "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
@@ -82,6 +84,8 @@ class FileApi(Resource):
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()
        except services.errors.file.BlockedFileExtensionError as blocked_extension_error:
            raise BlockedFileExtensionError(blocked_extension_error.description)

        return upload_file, 201


@@ -114,6 +114,25 @@ class PluginIconApi(Resource):
        return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)


@console_ns.route("/workspaces/current/plugin/asset")
class PluginAssetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        req = reqparse.RequestParser()
        req.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
        req.add_argument("file_name", type=str, required=True, location="args")
        args = req.parse_args()

        _, tenant_id = current_account_with_tenant()
        try:
            binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
            return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
        except PluginDaemonClientSideError as e:
            raise ValueError(e)


@console_ns.route("/workspaces/current/plugin/upload/pkg")
class PluginUploadFromPkgApi(Resource):
    @setup_required
@@ -558,19 +577,21 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
            .add_argument("provider", type=str, required=True, location="args")
            .add_argument("action", type=str, required=True, location="args")
            .add_argument("parameter", type=str, required=True, location="args")
            .add_argument("credential_id", type=str, required=False, location="args")
            .add_argument("provider_type", type=str, required=True, location="args")
        )
        args = parser.parse_args()

        try:
            options = PluginParameterService.get_dynamic_select_options(
                tenant_id,
                user_id,
                args["plugin_id"],
                args["provider"],
                args["action"],
                args["parameter"],
                args["provider_type"],
                tenant_id=tenant_id,
                user_id=user_id,
                plugin_id=args["plugin_id"],
                provider=args["provider"],
                action=args["action"],
                parameter=args["parameter"],
                credential_id=args["credential_id"],
                provider_type=args["provider_type"],
            )
        except PluginDaemonClientSideError as e:
            raise ValueError(e)
@@ -686,3 +707,23 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
        args = req.parse_args()

        return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})


@console_ns.route("/workspaces/current/plugin/readme")
class PluginReadmeApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        _, tenant_id = current_account_with_tenant()
        parser = reqparse.RequestParser()
        parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
        parser.add_argument("language", type=str, required=False, location="args")
        args = parser.parse_args()
        return jsonable_encoder(
            {
                "readme": PluginService.fetch_plugin_readme(
                    tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US")
                )
            }
        )

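A sketch of calling the new readme endpoint from a client; the host, token, and the identifier value are placeholders, while the query parameters come from the parser above:

    import requests

    resp = requests.get(
        "https://cloud.dify.ai/console/api/workspaces/current/plugin/readme",
        params={
            "plugin_unique_identifier": "<plugin-unique-identifier>",
            "language": "en-US",
        },
        headers={"Authorization": "Bearer <console-access-token>"},
    )
    print(resp.json()["readme"])
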
@@ -6,6 +6,7 @@ from flask_restx import (
    Resource,
    reqparse,
)
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from configs import dify_config
@@ -15,20 +16,23 @@ from controllers.console.wraps import (
    enterprise_license_required,
    setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import current_account_with_tenant, login_required
from models.provider_ids import ToolProviderID

# from models.provider_ids import ToolProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService, OAuthDataType
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
@@ -42,7 +46,9 @@ def is_valid_url(url: str) -> bool:
    try:
        parsed = urlparse(url)
        return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
    except Exception:
    except (ValueError, TypeError):
        # ValueError: Invalid URL format
        # TypeError: url is not a string
        return False


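The narrowed except clause keeps genuinely unexpected failures visible while still rejecting bad input: urlparse raises ValueError for malformed URLs such as an unclosed IPv6 literal, and a non-string argument fails per the comment above. The function with a few probe inputs (the inputs are illustrative):

    from urllib.parse import urlparse


    def is_valid_url(url: str) -> bool:
        try:
            parsed = urlparse(url)
            return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
        except (ValueError, TypeError):
            return False


    print(is_valid_url("https://mcp.example.com/sse"))  # True
    print(is_valid_url("ftp://example.com"))            # False, scheme not allowed
    print(is_valid_url("http://[::1"))                  # False, urlparse raises ValueError
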
@@ -886,29 +892,34 @@ class ToolProviderMCPApi(Resource):
            .add_argument("icon_type", type=str, required=True, nullable=False, location="json")
            .add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
            .add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
            .add_argument("timeout", type=float, required=False, nullable=False, location="json", default=30)
            .add_argument("sse_read_timeout", type=float, required=False, nullable=False, location="json", default=300)
            .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
            .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
            .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
        )
        args = parser.parse_args()
        user, tenant_id = current_account_with_tenant()
        if not is_valid_url(args["server_url"]):
            raise ValueError("Server URL is not valid.")
        return jsonable_encoder(
            MCPToolManageService.create_mcp_provider(

        # Parse and validate models
        configuration = MCPConfiguration.model_validate(args["configuration"])
        authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None

        # Create provider
        with Session(db.engine) as session, session.begin():
            service = MCPToolManageService(session=session)
            result = service.create_provider(
                tenant_id=tenant_id,
                user_id=user.id,
                server_url=args["server_url"],
                name=args["name"],
                icon=args["icon"],
                icon_type=args["icon_type"],
                icon_background=args["icon_background"],
                user_id=user.id,
                server_identifier=args["server_identifier"],
                timeout=args["timeout"],
                sse_read_timeout=args["sse_read_timeout"],
                headers=args["headers"],
                configuration=configuration,
                authentication=authentication,
            )
        )
        return jsonable_encoder(result)

    @setup_required
    @login_required
@@ -923,31 +934,43 @@ class ToolProviderMCPApi(Resource):
            .add_argument("icon_background", type=str, required=False, nullable=True, location="json")
            .add_argument("provider_id", type=str, required=True, nullable=False, location="json")
            .add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
            .add_argument("timeout", type=float, required=False, nullable=True, location="json")
            .add_argument("sse_read_timeout", type=float, required=False, nullable=True, location="json")
            .add_argument("headers", type=dict, required=False, nullable=True, location="json")
            .add_argument("configuration", type=dict, required=False, nullable=True, location="json", default={})
            .add_argument("headers", type=dict, required=False, nullable=True, location="json", default={})
            .add_argument("authentication", type=dict, required=False, nullable=True, location="json", default={})
        )
        args = parser.parse_args()
        if not is_valid_url(args["server_url"]):
            if "[__HIDDEN__]" in args["server_url"]:
                pass
            else:
                raise ValueError("Server URL is not valid.")
        configuration = MCPConfiguration.model_validate(args["configuration"])
        authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
        _, current_tenant_id = current_account_with_tenant()
        MCPToolManageService.update_mcp_provider(
            tenant_id=current_tenant_id,
            provider_id=args["provider_id"],
            server_url=args["server_url"],
            name=args["name"],
            icon=args["icon"],
            icon_type=args["icon_type"],
            icon_background=args["icon_background"],
            server_identifier=args["server_identifier"],
            timeout=args.get("timeout"),
            sse_read_timeout=args.get("sse_read_timeout"),
            headers=args.get("headers"),
        )
        return {"result": "success"}

        # Step 1: Validate server URL change if needed (includes URL format validation and network operation)
        validation_result = None
        with Session(db.engine) as session:
            service = MCPToolManageService(session=session)
            validation_result = service.validate_server_url_change(
                tenant_id=current_tenant_id, provider_id=args["provider_id"], new_server_url=args["server_url"]
            )

        # No need to check for errors here, exceptions will be raised directly

        # Step 2: Perform database update in a transaction
        with Session(db.engine) as session, session.begin():
            service = MCPToolManageService(session=session)
            service.update_provider(
                tenant_id=current_tenant_id,
                provider_id=args["provider_id"],
                server_url=args["server_url"],
                name=args["name"],
                icon=args["icon"],
                icon_type=args["icon_type"],
                icon_background=args["icon_background"],
                server_identifier=args["server_identifier"],
                headers=args["headers"],
                configuration=configuration,
                authentication=authentication,
                validation_result=validation_result,
            )
        return {"result": "success"}

    @setup_required
    @login_required
@@ -958,8 +981,11 @@ class ToolProviderMCPApi(Resource):
        )
        args = parser.parse_args()
        _, current_tenant_id = current_account_with_tenant()
        MCPToolManageService.delete_mcp_tool(tenant_id=current_tenant_id, provider_id=args["provider_id"])
        return {"result": "success"}

        with Session(db.engine) as session, session.begin():
            service = MCPToolManageService(session=session)
            service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
        return {"result": "success"}

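Splitting the update into two sessions keeps the network-bound URL validation outside the write transaction, so a slow MCP server cannot hold database locks open. The transactional half relies on session.begin() semantics, shown here in a self-contained toy (in-memory SQLite and an illustrative table, not the service internals):

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import Session

    engine = create_engine("sqlite://")  # throwaway in-memory database
    with engine.begin() as conn:
        conn.execute(text("CREATE TABLE providers (id INTEGER PRIMARY KEY, url TEXT)"))

    # Phase 1: a plain Session with no explicit transaction, a safe place
    # for slow validation work since no write locks are held yet.
    with Session(engine) as session:
        count = session.execute(text("SELECT COUNT(*) FROM providers")).scalar_one()

    # Phase 2: session.begin() opens a transaction that commits when the
    # block exits cleanly and rolls back if an exception escapes it.
    with Session(engine) as session, session.begin():
        session.execute(text("INSERT INTO providers (url) VALUES ('https://mcp.example.com')"))
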

@console_ns.route("/workspaces/current/tool-provider/mcp/auth")
@@ -976,37 +1002,53 @@ class ToolMCPAuthApi(Resource):
        args = parser.parse_args()
        provider_id = args["provider_id"]
        _, tenant_id = current_account_with_tenant()
        provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
        if not provider:
            raise ValueError("provider not found")
        try:
            with MCPClient(
                provider.decrypted_server_url,
                provider_id,
                tenant_id,
                authed=False,
                authorization_code=args["authorization_code"],
                for_list=True,
                headers=provider.decrypted_headers,
                timeout=provider.timeout,
                sse_read_timeout=provider.sse_read_timeout,
            ):
                MCPToolManageService.update_mcp_provider_credentials(
                    mcp_provider=provider,
                    credentials=provider.decrypted_credentials,
                    authed=True,
                )
                return {"result": "success"}

        except MCPAuthError:
            auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
            return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
        with Session(db.engine) as session, session.begin():
            service = MCPToolManageService(session=session)
            db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
            if not db_provider:
                raise ValueError("provider not found")

            # Convert to entity
            provider_entity = db_provider.to_entity()
            server_url = provider_entity.decrypt_server_url()
            headers = provider_entity.decrypt_authentication()

        # Try to connect without active transaction
        try:
            # Use MCPClientWithAuthRetry to handle authentication automatically
            with MCPClient(
                server_url=server_url,
                headers=headers,
                timeout=provider_entity.timeout,
                sse_read_timeout=provider_entity.sse_read_timeout,
            ):
                # Update credentials in new transaction
                with Session(db.engine) as session, session.begin():
                    service = MCPToolManageService(session=session)
                    service.update_provider_credentials(
                        provider_id=provider_id,
                        tenant_id=tenant_id,
                        credentials=provider_entity.credentials,
                        authed=True,
                    )
                return {"result": "success"}
        except MCPAuthError as e:
            try:
                auth_result = auth(provider_entity, args.get("authorization_code"))
                with Session(db.engine) as session, session.begin():
                    service = MCPToolManageService(session=session)
                    response = service.execute_auth_actions(auth_result)
                return response
            except MCPRefreshTokenError as e:
                with Session(db.engine) as session, session.begin():
                    service = MCPToolManageService(session=session)
                    service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
                raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
        except MCPError as e:
            MCPToolManageService.update_mcp_provider_credentials(
                mcp_provider=provider,
                credentials={},
                authed=False,
            )
            with Session(db.engine) as session, session.begin():
                service = MCPToolManageService(session=session)
                service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
            raise ValueError(f"Failed to connect to MCP server: {e}") from e


@@ -1017,8 +1059,10 @@ class ToolMCPDetailApi(Resource):
    @account_initialization_required
    def get(self, provider_id):
        _, tenant_id = current_account_with_tenant()
        provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
        return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
        with Session(db.engine) as session, session.begin():
            service = MCPToolManageService(session=session)
            provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
            return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))


@console_ns.route("/workspaces/current/tools/mcp")
@@ -1029,9 +1073,12 @@ class ToolMCPListAllApi(Resource):
    def get(self):
        _, tenant_id = current_account_with_tenant()

        tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
        with Session(db.engine) as session, session.begin():
            service = MCPToolManageService(session=session)
            # Skip sensitive data decryption for list view to improve performance
            tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)

        return [tool.to_dict() for tool in tools]
            return [tool.to_dict() for tool in tools]


@console_ns.route("/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
@@ -1041,11 +1088,13 @@ class ToolMCPUpdateApi(Resource):
    @account_initialization_required
    def get(self, provider_id):
        _, tenant_id = current_account_with_tenant()
        tools = MCPToolManageService.list_mcp_tool_from_remote_server(
            tenant_id=tenant_id,
            provider_id=provider_id,
        )
        return jsonable_encoder(tools)
        with Session(db.engine) as session, session.begin():
            service = MCPToolManageService(session=session)
            tools = service.list_provider_tools(
                tenant_id=tenant_id,
                provider_id=provider_id,
            )
            return jsonable_encoder(tools)


@console_ns.route("/mcp/oauth/callback")
@@ -1059,5 +1108,15 @@ class ToolMCPCallbackApi(Resource):
        args = parser.parse_args()
        state_key = args["state"]
        authorization_code = args["code"]
        handle_callback(state_key, authorization_code)

        # Create service instance for handle_callback
        with Session(db.engine) as session, session.begin():
            mcp_service = MCPToolManageService(session=session)
            # handle_callback now returns state data and tokens
            state_data, tokens = handle_callback(state_key, authorization_code)
            # Save tokens using the service layer
            mcp_service.save_oauth_data(
                state_data.provider_id, state_data.tenant_id, tokens.model_dump(), OAuthDataType.TOKENS
            )

        return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

592
api/controllers/console/workspace/trigger_providers.py
Normal file
@@ -0,0 +1,592 @@
import logging

from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden

from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler
from core.trigger.entities.entities import SubscriptionBuilderUpdater
from core.trigger.trigger_manager import TriggerManager
from extensions.ext_database import db
from libs.login import current_user, login_required
from models.account import Account
from models.provider_ids import TriggerProviderID
from services.plugin.oauth_service import OAuthProxyService
from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService

logger = logging.getLogger(__name__)


class TriggerProviderIconApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None

        return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)


class TriggerProviderListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """List all trigger providers for the current tenant"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))


class TriggerProviderInfoApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """Get info for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        return jsonable_encoder(
            TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider))
        )


class TriggerSubscriptionListApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """List all trigger subscriptions for the current tenant's provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            return jsonable_encoder(
                TriggerProviderService.list_trigger_provider_subscriptions(
                    tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
                )
            )
        except ValueError as e:
            return jsonable_encoder({"error": str(e)}), 404
        except Exception as e:
            logger.exception("Error listing trigger providers", exc_info=e)
            raise


class TriggerSubscriptionBuilderCreateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        """Add a new subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
        args = parser.parse_args()

        try:
            credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
            subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                provider_id=TriggerProviderID(provider),
                credential_type=credential_type,
            )
            return jsonable_encoder({"subscription_builder": subscription_builder})
        except Exception as e:
            logger.exception("Error adding provider credential", exc_info=e)
            raise


class TriggerSubscriptionBuilderGetApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider, subscription_builder_id):
        """Get a subscription instance for a trigger provider"""
        return jsonable_encoder(
            TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
        )


class TriggerSubscriptionBuilderVerifyApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider, subscription_builder_id):
        """Verify a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        # The credentials of the subscription builder
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        args = parser.parse_args()

        try:
            # Use atomic update_and_verify to prevent race conditions
            return TriggerSubscriptionBuilderService.update_and_verify_builder(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                provider_id=TriggerProviderID(provider),
                subscription_builder_id=subscription_builder_id,
                subscription_builder_updater=SubscriptionBuilderUpdater(
                    credentials=args.get("credentials", None),
                ),
            )
        except Exception as e:
            logger.exception("Error verifying provider credential", exc_info=e)
            raise ValueError(str(e)) from e


class TriggerSubscriptionBuilderUpdateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider, subscription_builder_id):
        """Update a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None

        parser = reqparse.RequestParser()
        # The name of the subscription builder
        parser.add_argument("name", type=str, required=False, nullable=True, location="json")
        # The parameters of the subscription builder
        parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
        # The properties of the subscription builder
        parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
        # The credentials of the subscription builder
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        args = parser.parse_args()
        try:
            return jsonable_encoder(
                TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
                    tenant_id=user.current_tenant_id,
                    provider_id=TriggerProviderID(provider),
                    subscription_builder_id=subscription_builder_id,
                    subscription_builder_updater=SubscriptionBuilderUpdater(
                        name=args.get("name", None),
                        parameters=args.get("parameters", None),
                        properties=args.get("properties", None),
                        credentials=args.get("credentials", None),
                    ),
                )
            )
        except Exception as e:
            logger.exception("Error updating provider credential", exc_info=e)
            raise


class TriggerSubscriptionBuilderLogsApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider, subscription_builder_id):
        """Get the request logs for a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None

        try:
            logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id)
            return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]})
        except Exception as e:
            logger.exception("Error getting request logs for subscription builder", exc_info=e)
            raise


class TriggerSubscriptionBuilderBuildApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider, subscription_builder_id):
        """Build a subscription instance for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        # The name of the subscription builder
        parser.add_argument("name", type=str, required=False, nullable=True, location="json")
        # The parameters of the subscription builder
        parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
        # The properties of the subscription builder
        parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
        # The credentials of the subscription builder
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        args = parser.parse_args()
        try:
            # Use atomic update_and_build to prevent race conditions
            TriggerSubscriptionBuilderService.update_and_build_builder(
                tenant_id=user.current_tenant_id,
                user_id=user.id,
                provider_id=TriggerProviderID(provider),
                subscription_builder_id=subscription_builder_id,
                subscription_builder_updater=SubscriptionBuilderUpdater(
                    name=args.get("name", None),
                    parameters=args.get("parameters", None),
                    properties=args.get("properties", None),
                ),
            )
            return 200
        except Exception as e:
            logger.exception("Error building provider credential", exc_info=e)
            raise ValueError(str(e)) from e


class TriggerSubscriptionDeleteApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, subscription_id: str):
        """Delete a subscription instance"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            with Session(db.engine) as session:
                # Delete trigger provider subscription
                TriggerProviderService.delete_trigger_provider(
                    session=session,
                    tenant_id=user.current_tenant_id,
                    subscription_id=subscription_id,
                )
                # Delete plugin triggers
                TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription(
                    session=session,
                    tenant_id=user.current_tenant_id,
                    subscription_id=subscription_id,
                )
                session.commit()
                return {"result": "success"}
        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error deleting provider credential", exc_info=e)
            raise


class TriggerOAuthAuthorizeApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """Initiate OAuth authorization flow for a trigger provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None

        try:
            provider_id = TriggerProviderID(provider)
            plugin_id = provider_id.plugin_id
            provider_name = provider_id.provider_name
            tenant_id = user.current_tenant_id

            # Get OAuth client configuration
            oauth_client_params = TriggerProviderService.get_oauth_client(
                tenant_id=tenant_id,
                provider_id=provider_id,
            )

            if oauth_client_params is None:
                raise NotFoundError("No OAuth client configuration found for this trigger provider")

            # Create subscription builder
            subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
                tenant_id=tenant_id,
                user_id=user.id,
                provider_id=provider_id,
                credential_type=CredentialType.OAUTH2,
            )

            # Create OAuth handler and proxy context
            oauth_handler = OAuthHandler()
            context_id = OAuthProxyService.create_proxy_context(
                user_id=user.id,
                tenant_id=tenant_id,
                plugin_id=plugin_id,
                provider=provider_name,
                extra_data={
                    "subscription_builder_id": subscription_builder.id,
                },
            )

            # Build redirect URI for callback
            redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"

            # Get authorization URL
            authorization_url_response = oauth_handler.get_authorization_url(
                tenant_id=tenant_id,
                user_id=user.id,
                plugin_id=plugin_id,
                provider=provider_name,
                redirect_uri=redirect_uri,
                system_credentials=oauth_client_params,
            )

            # Create response with cookie
            response = make_response(
                jsonable_encoder(
                    {
                        "authorization_url": authorization_url_response.authorization_url,
                        "subscription_builder_id": subscription_builder.id,
                        "subscription_builder": subscription_builder,
                    }
                )
            )
            response.set_cookie(
                "context_id",
                context_id,
                httponly=True,
                samesite="Lax",
                max_age=OAuthProxyService.__MAX_AGE__,
            )

            return response

        except Exception as e:
            logger.exception("Error initiating OAuth flow", exc_info=e)
            raise


class TriggerOAuthCallbackApi(Resource):
    @setup_required
    def get(self, provider):
        """Handle OAuth callback for trigger provider"""
        context_id = request.cookies.get("context_id")
        if not context_id:
            raise Forbidden("context_id not found")

        # Use and validate proxy context
        context = OAuthProxyService.use_proxy_context(context_id)
        if context is None:
            raise Forbidden("Invalid context_id")

        # Parse provider ID
        provider_id = TriggerProviderID(provider)
        plugin_id = provider_id.plugin_id
        provider_name = provider_id.provider_name
        user_id = context.get("user_id")
        tenant_id = context.get("tenant_id")
        subscription_builder_id = context.get("subscription_builder_id")

        # Get OAuth client configuration
        oauth_client_params = TriggerProviderService.get_oauth_client(
            tenant_id=tenant_id,
            provider_id=provider_id,
        )

        if oauth_client_params is None:
            raise Forbidden("No OAuth client configuration found for this trigger provider")

        # Get OAuth credentials from callback
        oauth_handler = OAuthHandler()
        redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"

        credentials_response = oauth_handler.get_credentials(
            tenant_id=tenant_id,
            user_id=user_id,
            plugin_id=plugin_id,
            provider=provider_name,
            redirect_uri=redirect_uri,
            system_credentials=oauth_client_params,
            request=request,
        )

        credentials = credentials_response.credentials
        expires_at = credentials_response.expires_at

        if not credentials:
            raise ValueError("Failed to get OAuth credentials from the provider.")

        # Update subscription builder
        TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
            tenant_id=tenant_id,
            provider_id=provider_id,
            subscription_builder_id=subscription_builder_id,
            subscription_builder_updater=SubscriptionBuilderUpdater(
                credentials=credentials,
                credential_expires_at=expires_at,
            ),
        )
        # Redirect to OAuth callback page
        return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")


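# For reference, the authorize/callback pair above is one browser round trip:
# GET .../subscriptions/oauth/authorize returns the provider's
# authorization_url and sets an HttpOnly context_id cookie, and the provider
# later redirects the browser to /oauth/plugin/<provider>/trigger/callback,
# where that cookie is exchanged for the stored proxy context. A client-side
# sketch (host, token, and provider path are illustrative):
#
#     import requests
#
#     s = requests.Session()
#     s.headers["Authorization"] = "Bearer <console-access-token>"
#     resp = s.get(
#         "https://cloud.dify.ai/console/api/workspaces/current"
#         "/trigger-provider/<plugin>/<provider>/subscriptions/oauth/authorize"
#     )
#     print(resp.json()["authorization_url"])  # open in a browser to continue

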
class TriggerOAuthClientManageApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """Get OAuth client configuration for a provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            provider_id = TriggerProviderID(provider)

            # Get custom OAuth client params if exists
            custom_params = TriggerProviderService.get_custom_oauth_client_params(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )

            # Check if custom client is enabled
            is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )
            system_client_exists = TriggerProviderService.is_oauth_system_client_exists(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )
            provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id)
            redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback"
            return jsonable_encoder(
                {
                    "configured": bool(custom_params or system_client_exists),
                    "system_configured": system_client_exists,
                    "custom_configured": bool(custom_params),
                    "oauth_client_schema": provider_controller.get_oauth_client_schema(),
                    "custom_enabled": is_custom_enabled,
                    "redirect_uri": redirect_uri,
                    "params": custom_params or {},
                }
            )

        except Exception as e:
            logger.exception("Error getting OAuth client", exc_info=e)
            raise

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider):
        """Configure custom OAuth client for a provider"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
        parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
        args = parser.parse_args()

        try:
            provider_id = TriggerProviderID(provider)
            return TriggerProviderService.save_custom_oauth_client_params(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
                client_params=args.get("client_params"),
                enabled=args.get("enabled"),
            )

        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error configuring OAuth client", exc_info=e)
            raise

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, provider):
        """Remove custom OAuth client configuration"""
        user = current_user
        assert isinstance(user, Account)
        assert user.current_tenant_id is not None
        if not user.is_admin_or_owner:
            raise Forbidden()

        try:
            provider_id = TriggerProviderID(provider)

            return TriggerProviderService.delete_custom_oauth_client_params(
                tenant_id=user.current_tenant_id,
                provider_id=provider_id,
            )
        except ValueError as e:
            raise BadRequest(str(e))
        except Exception as e:
            logger.exception("Error removing OAuth client", exc_info=e)
            raise


# Trigger Subscription
|
||||
api.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
|
||||
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
|
||||
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
|
||||
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
|
||||
api.add_resource(
|
||||
TriggerSubscriptionDeleteApi,
|
||||
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
|
||||
)
|
||||
|
||||
# Trigger Subscription Builder
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderGetApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
TriggerSubscriptionBuilderLogsApi,
|
||||
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
|
||||
)
|
||||
|
||||
|
||||
# OAuth
|
||||
api.add_resource(
|
||||
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
|
||||
)
|
||||
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
|
||||
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")
|
||||
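For orientation, a minimal sketch of exercising the OAuth client management route registered above. The host, console prefix, and session cookie are assumptions; the path and JSON fields follow the resource code above.

import requests  # assumption: plain HTTP client against a running console API

BASE = "http://localhost:5001/console/api"  # assumption: default console API prefix
provider = "langgenius/github_trigger/github"  # hypothetical provider id
session_cookies = {"session": "<console-session-cookie>"}  # hypothetical auth

# Save custom OAuth client params for the trigger provider (admin/owner only).
resp = requests.post(
    f"{BASE}/workspaces/current/trigger-provider/{provider}/oauth/client",
    json={"client_params": {"client_id": "...", "client_secret": "..."}, "enabled": True},
    cookies=session_cookies,
)
resp.raise_for_status()

# Remove the custom client again, falling back to the provider default.
requests.delete(
    f"{BASE}/workspaces/current/trigger-provider/{provider}/oauth/client",
    cookies=session_cookies,
).raise_for_status()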
@@ -21,6 +21,7 @@ from controllers.console.wraps import (
     cloud_edition_billing_resource_check,
     setup_required,
 )
+from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
 from libs.helper import TimestampField
 from libs.login import current_account_with_tenant, login_required
@@ -83,7 +84,7 @@ class TenantListApi(Resource):
                 "name": tenant.name,
                 "status": tenant.status,
                 "created_at": tenant.created_at,
-                "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
+                "plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
                 "current": tenant.id == current_tenant_id if current_tenant_id else False,
             }
@@ -10,6 +10,7 @@ from flask import abort, request
 
 from configs import dify_config
 from controllers.console.workspace.error import AccountNotInitializedError
+from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.login import current_account_with_tenant
@@ -133,7 +134,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
         features = FeatureService.get_features(current_tenant_id)
         if features.billing.enabled:
             if resource == "add_segment":
-                if features.billing.subscription.plan == "sandbox":
+                if features.billing.subscription.plan == CloudPlan.SANDBOX:
                     abort(
                         403,
                         "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
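These hunks replace the magic string "sandbox" with an enum member. A minimal sketch of what enums/cloud_plan.py presumably contains; only SANDBOX is visible in this diff, the other members are assumptions:

from enum import StrEnum


class CloudPlan(StrEnum):
    # StrEnum members compare equal to their raw string values, so existing
    # comparisons against "sandbox" keep working during the migration.
    SANDBOX = "sandbox"            # visible in the hunks above
    PROFESSIONAL = "professional"  # assumption
    TEAM = "team"                  # assumption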
@@ -14,10 +14,25 @@ from services.file_service import FileService
 
 @files_ns.route("/<uuid:file_id>/image-preview")
 class ImagePreviewApi(Resource):
-    """
-    Deprecated
-    """
+    """Deprecated endpoint for retrieving image previews."""
 
+    @files_ns.doc("get_image_preview")
+    @files_ns.doc(description="Retrieve a signed image preview for a file")
+    @files_ns.doc(
+        params={
+            "file_id": "ID of the file to preview",
+            "timestamp": "Unix timestamp used in the signature",
+            "nonce": "Random string used in the signature",
+            "sign": "HMAC signature verifying the request",
+        }
+    )
+    @files_ns.doc(
+        responses={
+            200: "Image preview returned successfully",
+            400: "Missing or invalid signature parameters",
+            415: "Unsupported file type",
+        }
+    )
     def get(self, file_id):
         file_id = str(file_id)
 
@@ -43,6 +58,25 @@ class ImagePreviewApi(Resource):
 
 @files_ns.route("/<uuid:file_id>/file-preview")
 class FilePreviewApi(Resource):
+    @files_ns.doc("get_file_preview")
+    @files_ns.doc(description="Download a file preview or attachment using signed parameters")
+    @files_ns.doc(
+        params={
+            "file_id": "ID of the file to preview",
+            "timestamp": "Unix timestamp used in the signature",
+            "nonce": "Random string used in the signature",
+            "sign": "HMAC signature verifying the request",
+            "as_attachment": "Whether to download the file as an attachment",
+        }
+    )
+    @files_ns.doc(
+        responses={
+            200: "File stream returned successfully",
+            400: "Missing or invalid signature parameters",
+            404: "File not found",
+            415: "Unsupported file type",
+        }
+    )
     def get(self, file_id):
         file_id = str(file_id)
 
@@ -101,6 +135,20 @@ class FilePreviewApi(Resource):
 
 @files_ns.route("/workspaces/<uuid:workspace_id>/webapp-logo")
 class WorkspaceWebappLogoApi(Resource):
+    @files_ns.doc("get_workspace_webapp_logo")
+    @files_ns.doc(description="Fetch the custom webapp logo for a workspace")
+    @files_ns.doc(
+        params={
+            "workspace_id": "Workspace identifier",
+        }
+    )
+    @files_ns.doc(
+        responses={
+            200: "Logo returned successfully",
+            404: "Webapp logo not configured",
+            415: "Unsupported file type",
+        }
+    )
     def get(self, workspace_id):
         workspace_id = str(workspace_id)
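The params documented above describe Dify's signed file URLs. A sketch of how such a signature is typically assembled; the message layout and secret handling are assumptions, the real scheme lives in the file-signing helpers:

import base64
import hashlib
import hmac
import time
import uuid

SECRET_KEY = b"app-secret"  # assumption: server-side signing secret


def sign_file_preview(file_id: str) -> dict[str, str]:
    timestamp = str(int(time.time()))
    nonce = uuid.uuid4().hex
    msg = f"file-preview|{file_id}|{timestamp}|{nonce}".encode()  # assumed layout
    digest = hmac.new(SECRET_KEY, msg, hashlib.sha256).digest()
    return {
        "timestamp": timestamp,
        "nonce": nonce,
        "sign": base64.urlsafe_b64encode(digest).decode(),
    }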
@@ -13,6 +13,26 @@ from extensions.ext_database import db as global_db
 
 @files_ns.route("/tools/<uuid:file_id>.<string:extension>")
 class ToolFileApi(Resource):
+    @files_ns.doc("get_tool_file")
+    @files_ns.doc(description="Download a tool file by ID using signed parameters")
+    @files_ns.doc(
+        params={
+            "file_id": "Tool file identifier",
+            "extension": "Expected file extension",
+            "timestamp": "Unix timestamp used in the signature",
+            "nonce": "Random string used in the signature",
+            "sign": "HMAC signature verifying the request",
+            "as_attachment": "Whether to download the file as an attachment",
+        }
+    )
+    @files_ns.doc(
+        responses={
+            200: "Tool file stream returned successfully",
+            403: "Forbidden - invalid signature",
+            404: "File not found",
+            415: "Unsupported file type",
+        }
+    )
     def get(self, file_id, extension):
         file_id = str(file_id)
@@ -592,7 +592,7 @@ class DocumentApi(DatasetApiResource):
                 "name": document.name,
                 "created_from": document.created_from,
                 "created_by": document.created_by,
-                "created_at": document.created_at.timestamp(),
+                "created_at": int(document.created_at.timestamp()),
                 "tokens": document.tokens,
                 "indexing_status": document.indexing_status,
                 "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -625,7 +625,7 @@ class DocumentApi(DatasetApiResource):
                 "name": document.name,
                 "created_from": document.created_from,
                 "created_by": document.created_by,
-                "created_at": document.created_at.timestamp(),
+                "created_at": int(document.created_at.timestamp()),
                 "tokens": document.tokens,
                 "indexing_status": document.indexing_status,
                 "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
@@ -2,6 +2,7 @@ from flask import request
 from flask_restx import marshal, reqparse
 from werkzeug.exceptions import NotFound
 
+from configs import dify_config
 from controllers.service_api import service_api_ns
 from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.wraps import (
@@ -107,6 +108,10 @@ class SegmentApi(DatasetApiResource):
         # validate args
         args = segment_create_parser.parse_args()
         if args["segments"] is not None:
+            segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
+            if segments_limit > 0 and len(args["segments"]) > segments_limit:
+                raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
+
             for args_item in args["segments"]:
                 SegmentService.segment_create_args_validate(args_item, document)
             segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
@@ -13,13 +13,15 @@ from sqlalchemy import select, update
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
 
+from enums.cloud_plan import CloudPlan
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs.datetime_utils import naive_utc_now
 from libs.login import current_user
 from models import Account, Tenant, TenantAccountJoin, TenantStatus
 from models.dataset import Dataset, RateLimitLog
-from models.model import ApiToken, App, DefaultEndUserSessionID, EndUser
+from models.model import ApiToken, App
+from services.end_user_service import EndUserService
 from services.feature_service import FeatureService
 
 P = ParamSpec("P")
@@ -66,6 +68,7 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
 
         kwargs["app_model"] = app_model
 
+        # If caller needs end-user context, attach EndUser to current_user
         if fetch_user_arg:
             if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
                 user_id = request.args.get("user")
@@ -74,7 +77,6 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
             elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
                 user_id = request.form.get("user")
             else:
-                # use default-user
                 user_id = None
 
             if not user_id and fetch_user_arg.required:
@@ -83,12 +85,34 @@
             if user_id:
                 user_id = str(user_id)
 
-                end_user = create_or_update_end_user_for_user_id(app_model, user_id)
+                end_user = EndUserService.get_or_create_end_user(app_model, user_id)
                 kwargs["end_user"] = end_user
 
                 # Set EndUser as current logged-in user for flask_login.current_user
                 current_app.login_manager._update_request_context_with_user(end_user)  # type: ignore
                 user_logged_in.send(current_app._get_current_object(), user=end_user)  # type: ignore
+            else:
+                # For service API without end-user context, ensure an Account is logged in
+                # so services relying on current_account_with_tenant() work correctly.
+                tenant_owner_info = (
+                    db.session.query(Tenant, Account)
+                    .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
+                    .join(Account, TenantAccountJoin.account_id == Account.id)
+                    .where(
+                        Tenant.id == app_model.tenant_id,
+                        TenantAccountJoin.role == "owner",
+                        Tenant.status == TenantStatus.NORMAL,
+                    )
+                    .one_or_none()
+                )
+
+                if tenant_owner_info:
+                    tenant_model, account = tenant_owner_info
+                    account.current_tenant = tenant_model
+                    current_app.login_manager._update_request_context_with_user(account)  # type: ignore
+                    user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore
+                else:
+                    raise Unauthorized("Tenant owner account not found or tenant is not active.")
 
         return view_func(*args, **kwargs)
 
@@ -138,7 +162,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
         features = FeatureService.get_features(api_token.tenant_id)
         if features.billing.enabled:
             if resource == "add_segment":
-                if features.billing.subscription.plan == "sandbox":
+                if features.billing.subscription.plan == CloudPlan.SANDBOX:
                     raise Forbidden(
                         "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
                     )
@@ -308,39 +332,6 @@ def validate_and_get_api_token(scope: str | None = None):
     return api_token
 
 
-def create_or_update_end_user_for_user_id(app_model: App, user_id: str | None = None) -> EndUser:
-    """
-    Create or update session terminal based on user ID.
-    """
-    if not user_id:
-        user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
-
-    with Session(db.engine, expire_on_commit=False) as session:
-        end_user = (
-            session.query(EndUser)
-            .where(
-                EndUser.tenant_id == app_model.tenant_id,
-                EndUser.app_id == app_model.id,
-                EndUser.session_id == user_id,
-                EndUser.type == "service_api",
-            )
-            .first()
-        )
-
-        if end_user is None:
-            end_user = EndUser(
-                tenant_id=app_model.tenant_id,
-                app_id=app_model.id,
-                type="service_api",
-                is_anonymous=user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID,
-                session_id=user_id,
-            )
-            session.add(end_user)
-            session.commit()
-
-    return end_user
-
-
 class DatasetApiResource(Resource):
     method_decorators = [validate_dataset_token]
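The deleted helper above now lives behind EndUserService.get_or_create_end_user; the call site suggests the service keeps the same signature and semantics. A sketch of its presumed shape (the body is an assumption inferred from the removed code):

from models.model import App, EndUser  # assumption: same models as before


class EndUserService:
    @staticmethod
    def get_or_create_end_user(app_model: App, user_id: str | None = None) -> EndUser:
        # Assumption: same look-up-or-create logic as the removed
        # create_or_update_end_user_for_user_id, centralized in one service.
        ...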
api/controllers/trigger/__init__.py (new file)
@@ -0,0 +1,12 @@
from flask import Blueprint

# Create trigger blueprint
bp = Blueprint("trigger", __name__, url_prefix="/triggers")

# Import routes after blueprint creation to avoid circular imports
from . import trigger, webhook

__all__ = [
    "trigger",
    "webhook",
]
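A minimal sketch of how this blueprint would be attached to the Flask app; the exact registration site in the app factory is an assumption:

from flask import Flask

from controllers.trigger import bp as trigger_bp

app = Flask(__name__)
# url_prefix="/triggers" above means routes land at /triggers/plugin/<id>
# and /triggers/webhook/<id> once the blueprint is registered.
app.register_blueprint(trigger_bp)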
api/controllers/trigger/trigger.py (new file)
@@ -0,0 +1,43 @@
import logging
import re

from flask import jsonify, request
from werkzeug.exceptions import NotFound

from controllers.trigger import bp
from services.trigger.trigger_service import TriggerService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService

logger = logging.getLogger(__name__)

UUID_PATTERN = r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$"
UUID_MATCHER = re.compile(UUID_PATTERN)


@bp.route("/plugin/<string:endpoint_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def trigger_endpoint(endpoint_id: str):
    """
    Handle endpoint trigger calls.
    """
    # endpoint_id must be a v4 UUID
    if not UUID_MATCHER.match(endpoint_id):
        raise NotFound("Invalid endpoint ID")
    handling_chain = [
        TriggerService.process_endpoint,
        TriggerSubscriptionBuilderService.process_builder_validation_endpoint,
    ]
    response = None
    try:
        for handler in handling_chain:
            response = handler(endpoint_id, request)
            if response:
                break
        if not response:
            logger.error("Endpoint not found for %s", endpoint_id)
            return jsonify({"error": "Endpoint not found"}), 404
        return response
    except ValueError as e:
        return jsonify({"error": "Endpoint processing failed", "message": str(e)}), 400
    except Exception:
        logger.exception("Endpoint processing failed for %s", endpoint_id)
        return jsonify({"error": "Internal server error"}), 500
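For reference, a hedged sketch of calling the plugin trigger endpoint above; the host and endpoint id are placeholders, and the id must satisfy the v4 UUID matcher:

import requests

endpoint_id = "123e4567-e89b-42d3-a456-426614174000"  # placeholder v4-shaped UUID
resp = requests.post(
    f"https://dify.example.com/triggers/plugin/{endpoint_id}",
    json={"event": "ping"},  # payload shape depends on the plugin trigger
)
print(resp.status_code, resp.text)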
api/controllers/trigger/webhook.py (new file)
@@ -0,0 +1,105 @@
import logging
import time

from flask import jsonify
from werkzeug.exceptions import NotFound, RequestEntityTooLarge

from controllers.trigger import bp
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
from services.trigger.webhook_service import WebhookService

logger = logging.getLogger(__name__)


def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
    """Fetch trigger context, extract request data, and validate payload using unified processing.

    Args:
        webhook_id: The webhook ID to process
        is_debug: If True, skip status validation for debug mode
    """
    webhook_trigger, workflow, node_config = WebhookService.get_webhook_trigger_and_workflow(
        webhook_id, is_debug=is_debug
    )

    try:
        # Use new unified extraction and validation
        webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
        return webhook_trigger, workflow, node_config, webhook_data, None
    except ValueError as e:
        # Fall back to raw extraction for error reporting
        webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
        return webhook_trigger, workflow, node_config, webhook_data, str(e)


@bp.route("/webhook/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook(webhook_id: str):
    """
    Handle webhook trigger calls.

    This endpoint receives webhook calls and processes them according to the
    configured webhook trigger settings.
    """
    try:
        webhook_trigger, workflow, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id)
        if error:
            return jsonify({"error": "Bad Request", "message": error}), 400

        # Process webhook call (send to Celery)
        WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)

        # Return configured response
        response_data, status_code = WebhookService.generate_webhook_response(node_config)
        return jsonify(response_data), status_code

    except ValueError as e:
        raise NotFound(str(e))
    except RequestEntityTooLarge:
        raise
    except Exception as e:
        logger.exception("Webhook processing failed for %s", webhook_id)
        return jsonify({"error": "Internal server error", "message": str(e)}), 500


@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook_debug(webhook_id: str):
    """Handle webhook debug calls without triggering production workflow execution."""
    try:
        webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
        if error:
            return jsonify({"error": "Bad Request", "message": error}), 400

        workflow_inputs = WebhookService.build_workflow_inputs(webhook_data)

        # Generate pool key and dispatch debug event
        pool_key: str = build_webhook_pool_key(
            tenant_id=webhook_trigger.tenant_id,
            app_id=webhook_trigger.app_id,
            node_id=webhook_trigger.node_id,
        )
        event = WebhookDebugEvent(
            request_id=f"webhook_debug_{webhook_trigger.webhook_id}_{int(time.time() * 1000)}",
            timestamp=int(time.time()),
            node_id=webhook_trigger.node_id,
            payload={
                "inputs": workflow_inputs,
                "webhook_data": webhook_data,
                "method": webhook_data.get("method"),
            },
        )
        TriggerDebugEventBus.dispatch(
            tenant_id=webhook_trigger.tenant_id,
            event=event,
            pool_key=pool_key,
        )
        response_data, status_code = WebhookService.generate_webhook_response(node_config)
        return jsonify(response_data), status_code

    except ValueError as e:
        raise NotFound(str(e))
    except RequestEntityTooLarge:
        raise
    except Exception as e:
        logger.exception("Webhook debug processing failed for %s", webhook_id)
        return jsonify({"error": "Internal server error", "message": "An internal error has occurred."}), 500
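The debug route above publishes WebhookDebugEvent objects onto a tenant-scoped pool. A minimal sketch of dispatching a synthetic event the same way, e.g. from a test; only names visible in this file are used, and the id values are placeholders:

import time

from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key

pool_key = build_webhook_pool_key(tenant_id="t-1", app_id="a-1", node_id="n-1")
event = WebhookDebugEvent(
    request_id=f"webhook_debug_manual_{int(time.time() * 1000)}",
    timestamp=int(time.time()),
    node_id="n-1",
    payload={"inputs": {}, "webhook_data": {"method": "POST"}, "method": "POST"},
)
TriggerDebugEventBus.dispatch(tenant_id="t-1", event=event, pool_key=pool_key)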
@@ -1,6 +1,6 @@
 import logging
 import time
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
 from sqlalchemy import select
@@ -25,6 +25,7 @@ from core.moderation.input_moderation import InputModeration
 from core.variables.variables import VariableUnion
 from core.workflow.enums import WorkflowType
 from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -61,11 +62,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         app: App,
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
     ):
         super().__init__(
             queue_manager=queue_manager,
             variable_loader=variable_loader,
             app_id=application_generate_entity.app_config.app_id,
+            graph_engine_layers=graph_engine_layers,
         )
         self.application_generate_entity = application_generate_entity
         self.conversation = conversation
@@ -195,6 +198,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         )
 
         workflow_entry.graph_engine.layer(persistence_layer)
+        for layer in self._graph_engine_layers:
+            workflow_entry.graph_engine.layer(layer)
 
         generator = workflow_entry.run()
@@ -1,3 +1,4 @@
+import json
 import logging
 import re
 import time
@@ -60,6 +61,7 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
 from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.workflow.enums import WorkflowExecutionStatus
 from core.workflow.nodes import NodeType
@@ -391,6 +393,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         if should_direct_answer:
             return
 
+        current_time = time.perf_counter()
+        if self._task_state.first_token_time is None and delta_text.strip():
+            self._task_state.first_token_time = current_time
+            self._task_state.is_streaming_response = True
+
+        if delta_text.strip():
+            self._task_state.last_token_time = current_time
+
         # Only publish tts message at text chunk streaming
         if tts_publisher and queue_message:
             tts_publisher.publish(queue_message)
@@ -772,7 +782,33 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         message.answer = answer_text
         message.updated_at = naive_utc_now()
         message.provider_response_latency = time.perf_counter() - self._base_task_pipeline.start_at
-        message.message_metadata = self._task_state.metadata.model_dump_json()
+
+        # Set usage first before dumping metadata
+        if graph_runtime_state and graph_runtime_state.llm_usage:
+            usage = graph_runtime_state.llm_usage
+            message.message_tokens = usage.prompt_tokens
+            message.message_unit_price = usage.prompt_unit_price
+            message.message_price_unit = usage.prompt_price_unit
+            message.answer_tokens = usage.completion_tokens
+            message.answer_unit_price = usage.completion_unit_price
+            message.answer_price_unit = usage.completion_price_unit
+            message.total_price = usage.total_price
+            message.currency = usage.currency
+            self._task_state.metadata.usage = usage
+        else:
+            usage = LLMUsage.empty_usage()
+            self._task_state.metadata.usage = usage
+
+        # Add streaming metrics to usage if available
+        if self._task_state.is_streaming_response and self._task_state.first_token_time:
+            start_time = self._base_task_pipeline.start_at
+            first_token_time = self._task_state.first_token_time
+            last_token_time = self._task_state.last_token_time or first_token_time
+            usage.time_to_first_token = round(first_token_time - start_time, 3)
+            usage.time_to_generate = round(last_token_time - first_token_time, 3)
+
+        metadata = self._task_state.metadata.model_dump()
+        message.message_metadata = json.dumps(jsonable_encoder(metadata))
         message_files = [
             MessageFile(
                 message_id=message.id,
@@ -790,20 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         ]
         session.add_all(message_files)
 
-        if graph_runtime_state and graph_runtime_state.llm_usage:
-            usage = graph_runtime_state.llm_usage
-            message.message_tokens = usage.prompt_tokens
-            message.message_unit_price = usage.prompt_unit_price
-            message.message_price_unit = usage.prompt_price_unit
-            message.answer_tokens = usage.completion_tokens
-            message.answer_unit_price = usage.completion_unit_price
-            message.answer_price_unit = usage.completion_price_unit
-            message.total_price = usage.total_price
-            message.currency = usage.currency
-            self._task_state.metadata.usage = usage
-        else:
-            self._task_state.metadata.usage = LLMUsage.empty_usage()
-
     def _seed_graph_runtime_state_from_queue_manager(self) -> None:
         """Bootstrap the cached runtime state from the queue manager when present."""
         candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
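The streaming metrics added above reduce to simple arithmetic over perf_counter timestamps; a worked sketch with illustrative values:

start_at = 100.000           # pipeline start (time.perf_counter())
first_token_time = 100.412   # first non-empty delta observed
last_token_time = 103.987    # last non-empty delta observed

time_to_first_token = round(first_token_time - start_at, 3)      # 0.412 s
time_to_generate = round(last_token_time - first_token_time, 3)  # 3.575 s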
@@ -144,7 +144,7 @@ class AgentChatAppRunner(AppRunner):
             prompt_template_entity=app_config.prompt_template,
             inputs=dict(inputs),
             files=list(files),
-            query=query or "",
+            query=query,
             memory=memory,
         )
 
@@ -172,7 +172,7 @@ class AgentChatAppRunner(AppRunner):
             prompt_template_entity=app_config.prompt_template,
             inputs=dict(inputs),
             files=list(files),
-            query=query or "",
+            query=query,
             memory=memory,
         )
@@ -79,7 +79,7 @@ class AppRunner:
         prompt_template_entity: PromptTemplateEntity,
         inputs: Mapping[str, str],
         files: Sequence["File"],
-        query: str | None = None,
+        query: str = "",
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
@@ -105,7 +105,7 @@ class AppRunner:
             app_mode=AppMode.value_of(app_record.mode),
             prompt_template_entity=prompt_template_entity,
             inputs=inputs,
-            query=query or "",
+            query=query,
             files=files,
             context=context,
             memory=memory,
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, NewType, Union
 
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.queue_entities import (
     QueueAgentLogEvent,
     QueueIterationCompletedEvent,
@@ -37,6 +37,7 @@ from core.file import FILE_MODEL_IDENTITY, File
 from core.plugin.impl.datasource import PluginDatasourceManager
 from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.tool_manager import ToolManager
+from core.trigger.trigger_manager import TriggerManager
 from core.variables.segments import ArrayFileSegment, FileSegment, Segment
 from core.workflow.enums import (
     NodeType,
@@ -51,7 +52,7 @@ from core.workflow.workflow_entry import WorkflowEntry
 from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
 from libs.datetime_utils import naive_utc_now
 from models import Account, EndUser
-from services.variable_truncator import VariableTruncator
+from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
 
 NodeExecutionId = NewType("NodeExecutionId", str)
 
@@ -70,6 +71,8 @@ class _NodeSnapshot:
 
 
 class WorkflowResponseConverter:
+    _truncator: BaseTruncator
+
     def __init__(
         self,
         *,
@@ -81,7 +84,13 @@ class WorkflowResponseConverter:
         self._user = user
         self._system_variables = system_variables
         self._workflow_inputs = self._prepare_workflow_inputs()
-        self._truncator = VariableTruncator.default()
+
+        # Disable truncation for SERVICE_API calls to keep backward compatibility.
+        if application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
+            self._truncator = DummyVariableTruncator()
+        else:
+            self._truncator = VariableTruncator.default()
 
         self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
         self._workflow_execution_id: str | None = None
         self._workflow_started_at: datetime | None = None
@@ -295,6 +304,11 @@ class WorkflowResponseConverter:
             response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
                 self._application_generate_entity.app_config.tenant_id
             )
+        elif event.node_type == NodeType.TRIGGER_PLUGIN:
+            response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
+                self._application_generate_entity.app_config.tenant_id,
+                event.provider_id,
+            )
 
         return response
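The SERVICE_API branch above swaps in a null-object truncator. Only the class name and its BaseTruncator parent are visible here; a sketch of the shape it presumably has (the method name is hypothetical):

class DummyVariableTruncator(BaseTruncator):
    # Assumption: pass-through no-op so service-API responses keep full,
    # untruncated variable payloads.
    def truncate(self, value):  # hypothetical method name
        return value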
@@ -190,7 +190,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
             conversation_id=conversation.id,
             inputs=application_generate_entity.inputs,
-            query=application_generate_entity.query or "",
+            query=application_generate_entity.query,
             message="",
             message_tokens=0,
             message_unit_price=0,
@@ -41,18 +41,14 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
 from extensions.ext_database import db
-from extensions.ext_redis import redis_client
 from libs.flask_utils import preserve_flask_contexts
 from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
 from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
 from models.enums import WorkflowRunTriggeredFrom
 from models.model import AppMode
 from services.datasource_provider_service import DatasourceProviderService
-from services.feature_service import FeatureService
-from services.file_service import FileService
+from services.rag_pipeline.rag_pipeline_task_proxy import RagPipelineTaskProxy
 from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
-from tasks.rag_pipeline.priority_rag_pipeline_run_task import priority_rag_pipeline_run_task
-from tasks.rag_pipeline.rag_pipeline_run_task import rag_pipeline_run_task
 
 logger = logging.getLogger(__name__)
 
@@ -248,34 +244,7 @@ class PipelineGenerator(BaseAppGenerator):
         )
 
         if rag_pipeline_invoke_entities:
-            # store the rag_pipeline_invoke_entities to object storage
-            text = [item.model_dump() for item in rag_pipeline_invoke_entities]
-            name = "rag_pipeline_invoke_entities.json"
-            # Convert list to proper JSON string
-            json_text = json.dumps(text)
-            upload_file = FileService(db.engine).upload_text(json_text, name, user.id, dataset.tenant_id)
-            features = FeatureService.get_features(dataset.tenant_id)
-            if features.billing.enabled and features.billing.subscription.plan == "sandbox":
-                tenant_pipeline_task_key = f"tenant_pipeline_task:{dataset.tenant_id}"
-                tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{dataset.tenant_id}"
-
-                if redis_client.get(tenant_pipeline_task_key):
-                    # Add to waiting queue using List operations (lpush)
-                    redis_client.lpush(tenant_self_pipeline_task_queue, upload_file.id)
-                else:
-                    # Set flag and execute task
-                    redis_client.set(tenant_pipeline_task_key, 1, ex=60 * 60)
-                    rag_pipeline_run_task.delay(  # type: ignore
-                        rag_pipeline_invoke_entities_file_id=upload_file.id,
-                        tenant_id=dataset.tenant_id,
-                    )
-
-            else:
-                priority_rag_pipeline_run_task.delay(  # type: ignore
-                    rag_pipeline_invoke_entities_file_id=upload_file.id,
-                    tenant_id=dataset.tenant_id,
-                )
+            RagPipelineTaskProxy(dataset.tenant_id, user.id, rag_pipeline_invoke_entities).delay()
         # return batch, dataset, documents
         return {
             "batch": batch,
@@ -27,6 +27,7 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.repositories import DifyCoreRepositoryFactory
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -38,10 +39,16 @@ from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTrigger
 from models.enums import WorkflowRunTriggeredFrom
 from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
 
+SKIP_PREPARE_USER_INPUTS_KEY = "_skip_prepare_user_inputs"
+
 logger = logging.getLogger(__name__)
 
 
 class WorkflowAppGenerator(BaseAppGenerator):
+    @staticmethod
+    def _should_prepare_user_inputs(args: Mapping[str, Any]) -> bool:
+        return not bool(args.get(SKIP_PREPARE_USER_INPUTS_KEY))
+
     @overload
     def generate(
         self,
@@ -53,7 +60,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: Literal[True],
         call_depth: int,
-    ) -> Generator[Mapping | str, None, None]: ...
+        triggered_from: WorkflowRunTriggeredFrom | None = None,
+        root_node_id: str | None = None,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
+    ) -> Generator[Mapping[str, Any] | str, None, None]: ...
 
     @overload
     def generate(
@@ -66,6 +76,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: Literal[False],
         call_depth: int,
+        triggered_from: WorkflowRunTriggeredFrom | None = None,
+        root_node_id: str | None = None,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
     ) -> Mapping[str, Any]: ...
 
     @overload
@@ -79,7 +92,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool,
         call_depth: int,
-    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
+        triggered_from: WorkflowRunTriggeredFrom | None = None,
+        root_node_id: str | None = None,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
+    ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
 
     def generate(
         self,
@@ -91,7 +107,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool = True,
         call_depth: int = 0,
-    ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]:
+        triggered_from: WorkflowRunTriggeredFrom | None = None,
+        root_node_id: str | None = None,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
+    ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
         files: Sequence[Mapping[str, Any]] = args.get("files") or []
 
         # parse files
@@ -126,17 +145,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
             **extract_external_trace_id_from_args(args),
         }
         workflow_run_id = str(uuid.uuid4())
+        # For trigger debug runs, do not prepare user inputs.
+        if self._should_prepare_user_inputs(args):
+            inputs = self._prepare_user_inputs(
+                user_inputs=inputs,
+                variables=app_config.variables,
+                tenant_id=app_model.tenant_id,
+                strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
+            )
         # init application generate entity
         application_generate_entity = WorkflowAppGenerateEntity(
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             file_upload_config=file_extra_config,
-            inputs=self._prepare_user_inputs(
-                user_inputs=inputs,
-                variables=app_config.variables,
-                tenant_id=app_model.tenant_id,
-                strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
-            ),
+            inputs=inputs,
             files=list(system_files),
             user_id=user.id,
             stream=streaming,
@@ -155,7 +177,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
         # Create session factory
         session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
         # Create workflow execution(aka workflow run) repository
-        if invoke_from == InvokeFrom.DEBUGGER:
+        if triggered_from is not None:
+            # Use explicitly provided triggered_from (for async triggers)
+            workflow_triggered_from = triggered_from
+        elif invoke_from == InvokeFrom.DEBUGGER:
             workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
         else:
             workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
@@ -182,8 +207,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
             workflow_execution_repository=workflow_execution_repository,
             workflow_node_execution_repository=workflow_node_execution_repository,
             streaming=streaming,
+            root_node_id=root_node_id,
+            graph_engine_layers=graph_engine_layers,
         )
 
+    def resume(self, *, workflow_run_id: str) -> None:
+        """
+        @TBD
+        """
+        pass
+
     def _generate(
         self,
         *,
@@ -196,6 +229,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         streaming: bool = True,
         variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
+        root_node_id: str | None = None,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
     ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
         """
         Generate App response.
@@ -231,8 +266,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
                 "queue_manager": queue_manager,
                 "context": context,
                 "variable_loader": variable_loader,
+                "root_node_id": root_node_id,
                 "workflow_execution_repository": workflow_execution_repository,
                 "workflow_node_execution_repository": workflow_node_execution_repository,
+                "graph_engine_layers": graph_engine_layers,
             },
         )
 
@@ -426,6 +463,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
         variable_loader: VariableLoader,
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
+        root_node_id: str | None = None,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
     ) -> None:
         """
        Generate worker in a new thread.
@@ -469,6 +508,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
             system_user_id=system_user_id,
             workflow_execution_repository=workflow_execution_repository,
             workflow_node_execution_repository=workflow_node_execution_repository,
+            root_node_id=root_node_id,
+            graph_engine_layers=graph_engine_layers,
         )
 
         try:
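A hedged sketch of how an async trigger path might call the extended generate(); parameter values are illustrative, the WorkflowRunTriggeredFrom member name is an assumption, and app_model/workflow/end_user are assumed to be loaded elsewhere:

result = WorkflowAppGenerator().generate(
    app_model=app_model,
    workflow=workflow,
    user=end_user,
    args={"inputs": raw_inputs, SKIP_PREPARE_USER_INPUTS_KEY: True},  # skip input prep for trigger debug
    invoke_from=InvokeFrom.TRIGGER,
    streaming=False,
    triggered_from=WorkflowRunTriggeredFrom.TRIGGER,  # assumption: enum member name
    graph_engine_layers=[TriggerPostLayer(...)],  # layer defined later in this diff
)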
@@ -1,5 +1,6 @@
 import logging
 import time
+from collections.abc import Sequence
 from typing import cast
 
 from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -8,6 +9,7 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
 from core.workflow.enums import WorkflowType
 from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -16,6 +18,7 @@ from core.workflow.system_variable import SystemVariable
 from core.workflow.variable_loader import VariableLoader
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
 from models.enums import UserFrom
 from models.workflow import Workflow
 
@@ -35,17 +38,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         variable_loader: VariableLoader,
         workflow: Workflow,
         system_user_id: str,
+        root_node_id: str | None = None,
         workflow_execution_repository: WorkflowExecutionRepository,
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
     ):
         super().__init__(
             queue_manager=queue_manager,
             variable_loader=variable_loader,
             app_id=application_generate_entity.app_config.app_id,
+            graph_engine_layers=graph_engine_layers,
         )
         self.application_generate_entity = application_generate_entity
         self._workflow = workflow
         self._sys_user_id = system_user_id
+        self._root_node_id = root_node_id
         self._workflow_execution_repository = workflow_execution_repository
         self._workflow_node_execution_repository = workflow_node_execution_repository
 
@@ -60,6 +67,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             files=self.application_generate_entity.files,
             user_id=self._sys_user_id,
             app_id=app_config.app_id,
+            timestamp=int(naive_utc_now().timestamp()),
             workflow_id=app_config.workflow_id,
             workflow_execution_id=self.application_generate_entity.workflow_execution_id,
         )
@@ -92,6 +100,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
             workflow_id=self._workflow.id,
             tenant_id=self._workflow.tenant_id,
             user_id=self.application_generate_entity.user_id,
+            root_node_id=self._root_node_id,
         )
 
         # RUN WORKFLOW
@@ -135,6 +144,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
         )
 
         workflow_entry.graph_engine.layer(persistence_layer)
+        for layer in self._graph_engine_layers:
+            workflow_entry.graph_engine.layer(layer)
 
         generator = workflow_entry.run()
@@ -1,5 +1,5 @@
 import time
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -27,6 +27,7 @@ from core.app.entities.queue_entities import (
 )
 from core.workflow.entities import GraphInitParams
 from core.workflow.graph import Graph
+from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_events import (
     GraphEngineEvent,
     GraphRunFailedEvent,
@@ -69,10 +70,12 @@ class WorkflowBasedAppRunner:
         queue_manager: AppQueueManager,
         variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
         app_id: str,
+        graph_engine_layers: Sequence[GraphEngineLayer] = (),
     ):
         self._queue_manager = queue_manager
         self._variable_loader = variable_loader
         self._app_id = app_id
+        self._graph_engine_layers = graph_engine_layers
 
     def _init_graph(
         self,
@@ -81,6 +84,7 @@ class WorkflowBasedAppRunner:
         workflow_id: str = "",
         tenant_id: str = "",
         user_id: str = "",
+        root_node_id: str | None = None,
     ) -> Graph:
         """
         Init graph
@@ -114,7 +118,7 @@ class WorkflowBasedAppRunner:
         )
 
         # init graph
-        graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
+        graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
 
         if not graph:
             raise ValueError("graph not found in workflow")
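Several hunks in this diff thread a graph_engine_layers sequence from the generators down to the runners above. A minimal sketch of a custom layer and how it rides along; the hook names follow the layer classes later in this diff, and the logging body is illustrative:

import logging

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent

logger = logging.getLogger(__name__)


class LoggingLayer(GraphEngineLayer):
    """Illustrative layer: logs every engine event."""

    def on_graph_start(self) -> None:
        logger.info("graph started")

    def on_event(self, event: GraphEngineEvent) -> None:
        logger.info("engine event: %s", type(event).__name__)

    def on_graph_end(self, error: Exception | None) -> None:
        logger.info("graph ended, error=%s", error)

# Runners then attach it alongside the persistence layer, e.g.:
#   WorkflowAppRunner(..., graph_engine_layers=[LoggingLayer()])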
@@ -32,6 +32,10 @@ class InvokeFrom(StrEnum):
     # https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
     WEB_APP = "web-app"
 
+    # TRIGGER indicates that this invocation is from a trigger;
+    # it is used for plugin triggers and webhook triggers.
+    TRIGGER = "trigger"
+
     # EXPLORE indicates that this invocation is from
     # the workflow (or chatflow) explore page.
     EXPLORE = "explore"
@@ -40,6 +44,9 @@ class InvokeFrom(StrEnum):
     DEBUGGER = "debugger"
     PUBLISHED = "published"
 
+    # VALIDATION indicates that this invocation is from validation.
+    VALIDATION = "validation"
+
     @classmethod
     def value_of(cls, value: str):
         """
@@ -65,6 +72,8 @@ class InvokeFrom(StrEnum):
             return "dev"
         elif self == InvokeFrom.EXPLORE:
             return "explore_app"
+        elif self == InvokeFrom.TRIGGER:
+            return "trigger"
         elif self == InvokeFrom.SERVICE_API:
             return "api"
@@ -104,6 +113,11 @@ class AppGenerateEntity(BaseModel):
 
     inputs: Mapping[str, Any]
     files: Sequence[File]
+
+    # Unique identifier of the user initiating the execution.
+    # This corresponds to `Account.id` for platform users or `EndUser.id` for end users.
+    #
+    # Note: The `user_id` field does not indicate whether the user is a platform user or an end user.
     user_id: str
 
     # extras
@@ -129,7 +143,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
     app_config: EasyUIBasedAppConfig = None  # type: ignore
     model_conf: ModelConfigWithCredentialsEntity
 
-    query: str | None = None
+    query: str = ""
 
     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
@@ -48,6 +48,9 @@ class WorkflowTaskState(TaskState):
     """
 
     answer: str = ""
+    first_token_time: float | None = None
+    last_token_time: float | None = None
+    is_streaming_response: bool = False
 
 
 class StreamEvent(StrEnum):
api/core/app/layers/pause_state_persist_layer.py (new file)
@@ -0,0 +1,133 @@
from typing import Annotated, Literal, Self, TypeAlias

from pydantic import BaseModel, Field
from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker

from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent
from models.model import AppMode
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory


# Wrapper types for `WorkflowAppGenerateEntity` and
# `AdvancedChatAppGenerateEntity`. These wrappers enable type discrimination
# and correct reconstruction of the entity field during (de)serialization.
class _WorkflowGenerateEntityWrapper(BaseModel):
    type: Literal[AppMode.WORKFLOW] = AppMode.WORKFLOW
    entity: WorkflowAppGenerateEntity


class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
    type: Literal[AppMode.ADVANCED_CHAT] = AppMode.ADVANCED_CHAT
    entity: AdvancedChatAppGenerateEntity


_GenerateEntityUnion: TypeAlias = Annotated[
    _WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
    Field(discriminator="type"),
]


class WorkflowResumptionContext(BaseModel):
    """WorkflowResumptionContext captures all state necessary for resumption."""

    version: Literal["1"] = "1"

    # Only workflow / chatflow runs can be paused.
    generate_entity: _GenerateEntityUnion
    serialized_graph_runtime_state: str

    def dumps(self) -> str:
        return self.model_dump_json()

    @classmethod
    def loads(cls, value: str) -> Self:
        return cls.model_validate_json(value)

    def get_generate_entity(self) -> WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity:
        return self.generate_entity.entity


class PauseStatePersistenceLayer(GraphEngineLayer):
    def __init__(
        self,
        session_factory: Engine | sessionmaker[Session],
        generate_entity: WorkflowAppGenerateEntity | AdvancedChatAppGenerateEntity,
        state_owner_user_id: str,
    ):
        """Create a PauseStatePersistenceLayer.

        The `state_owner_user_id` is used when creating the state file for a pause.
        It should generally be the ID of the workflow's creator.
        """
        if isinstance(session_factory, Engine):
            session_factory = sessionmaker(session_factory)
        self._session_maker = session_factory
        self._state_owner_user_id = state_owner_user_id
        self._generate_entity = generate_entity

    def _get_repo(self) -> APIWorkflowRunRepository:
        return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)

    def on_graph_start(self) -> None:
        """
        Called when graph execution starts.

        This is called after the engine has been initialized but before any nodes
        are executed. Layers can use this to set up resources or log start information.
        """
        pass

    def on_event(self, event: GraphEngineEvent) -> None:
        """
        Called for every event emitted by the engine.

        This method receives all events generated during graph execution, including:
        - Graph lifecycle events (start, success, failure)
        - Node execution events (start, success, failure, retry)
        - Stream events for response nodes
        - Container events (iteration, loop)

        Args:
            event: The event emitted by the engine
        """
        if not isinstance(event, GraphRunPausedEvent):
            return

        assert self.graph_runtime_state is not None

        entity_wrapper: _GenerateEntityUnion
        if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
            entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
        else:
            entity_wrapper = _AdvancedChatAppGenerateEntityWrapper(entity=self._generate_entity)

        state = WorkflowResumptionContext(
            serialized_graph_runtime_state=self.graph_runtime_state.dumps(),
            generate_entity=entity_wrapper,
        )

        workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id
        assert workflow_run_id is not None
        repo = self._get_repo()
        repo.create_workflow_pause(
            workflow_run_id=workflow_run_id,
            state_owner_user_id=self._state_owner_user_id,
            state=state.dumps(),
        )

    def on_graph_end(self, error: Exception | None) -> None:
        """
        Called when graph execution ends.

        This is called after all nodes have been executed or when execution is
        aborted. Layers can use this to clean up resources or log final state.

        Args:
            error: The exception that caused execution to fail, or None if successful
        """
        pass
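A brief round-trip sketch of the resumption context above; assume entity is a WorkflowAppGenerateEntity built elsewhere and runtime_state_json came from graph_runtime_state.dumps():

ctx = WorkflowResumptionContext(
    generate_entity=_WorkflowGenerateEntityWrapper(entity=entity),
    serialized_graph_runtime_state=runtime_state_json,
)
blob = ctx.dumps()  # JSON string handed to repo.create_workflow_pause
restored = WorkflowResumptionContext.loads(blob)
assert isinstance(restored.get_generate_entity(), WorkflowAppGenerateEntity)  # discriminator resolved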
api/core/app/layers/suspend_layer.py (new file)
@@ -0,0 +1,21 @@
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunPausedEvent


class SuspendLayer(GraphEngineLayer):
    """Placeholder layer for suspending paused workflow runs (not yet implemented)."""

    def on_graph_start(self):
        pass

    def on_event(self, event: GraphEngineEvent):
        """
        Handle the paused event: stash runtime state into storage and wait for resume.
        """
        if isinstance(event, GraphRunPausedEvent):
            pass

    def on_graph_end(self, error: Exception | None):
        """Nothing to clean up yet."""
        pass
api/core/app/layers/timeslice_layer.py (new file)
@@ -0,0 +1,88 @@
import logging
import uuid
from typing import ClassVar

from apscheduler.schedulers.background import BackgroundScheduler  # type: ignore

from core.workflow.graph_engine.entities.commands import CommandType, GraphEngineCommand
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from services.workflow.entities import WorkflowScheduleCFSPlanEntity
from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand

logger = logging.getLogger(__name__)


class TimeSliceLayer(GraphEngineLayer):
    """
    CFS plan scheduler layer that controls the timeslice of the workflow.
    """

    scheduler: ClassVar[BackgroundScheduler] = BackgroundScheduler()

    def __init__(self, cfs_plan_scheduler: CFSPlanScheduler) -> None:
        """
        The CFS plan scheduler controls the timeslice granted to the workflow.
        """
        if not TimeSliceLayer.scheduler.running:
            TimeSliceLayer.scheduler.start()

        super().__init__()
        self.cfs_plan_scheduler = cfs_plan_scheduler
        self.stopped = False
        self.schedule_id = ""

    def _checker_job(self, schedule_id: str):
        """
        Check whether the workflow needs to be suspended.
        """
        try:
            if self.stopped:
                self.scheduler.remove_job(schedule_id)
                return

            if self.cfs_plan_scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED:
                # remove the job
                self.scheduler.remove_job(schedule_id)

                if not self.command_channel:
                    logger.error("No command channel to stop the workflow")
                    return

                # send command to pause the workflow
                self.command_channel.send_command(
                    GraphEngineCommand(
                        command_type=CommandType.PAUSE,
                        payload={
                            "reason": SchedulerCommand.RESOURCE_LIMIT_REACHED,
                        },
                    )
                )

        except Exception:
            logger.exception("scheduler error while checking whether the workflow needs to be suspended")

    def on_graph_start(self):
        """
        Start a timer that checks whether the workflow needs to be suspended.
        """
        if self.cfs_plan_scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice:
            self.schedule_id = uuid.uuid4().hex

            self.scheduler.add_job(
                lambda: self._checker_job(self.schedule_id),
                "interval",
                seconds=self.cfs_plan_scheduler.plan.granularity,
                id=self.schedule_id,
            )

    def on_event(self, event: GraphEngineEvent):
        pass

    def on_graph_end(self, error: Exception | None) -> None:
        self.stopped = True
        # remove the scheduled checker job
        if self.schedule_id:
            self.scheduler.remove_job(self.schedule_id)
api/core/app/layers/trigger_post_layer.py (new file, +88)
@@ -0,0 +1,88 @@
import logging
from datetime import UTC, datetime
from typing import Any, ClassVar

from pydantic import TypeAdapter
from sqlalchemy.orm import Session, sessionmaker

from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
from models.enums import WorkflowTriggerStatus
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity

logger = logging.getLogger(__name__)


class TriggerPostLayer(GraphEngineLayer):
    """
    Layer that records the outcome of a triggered workflow run in its trigger log.
    """

    _STATUS_MAP: ClassVar[dict[type[GraphEngineEvent], WorkflowTriggerStatus]] = {
        GraphRunSucceededEvent: WorkflowTriggerStatus.SUCCEEDED,
        GraphRunFailedEvent: WorkflowTriggerStatus.FAILED,
        GraphRunPausedEvent: WorkflowTriggerStatus.PAUSED,
    }

    def __init__(
        self,
        cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
        start_time: datetime,
        trigger_log_id: str,
        session_maker: sessionmaker[Session],
    ):
        super().__init__()
        self.trigger_log_id = trigger_log_id
        self.start_time = start_time
        self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
        self.session_maker = session_maker

    def on_graph_start(self):
        pass

    def on_event(self, event: GraphEngineEvent):
        """
        Update the trigger log with success, failure, or pause.
        """
        if isinstance(event, tuple(self._STATUS_MAP.keys())):
            with self.session_maker() as session:
                repo = SQLAlchemyWorkflowTriggerLogRepository(session)
                trigger_log = repo.get_by_id(self.trigger_log_id)
                if not trigger_log:
                    logger.error("Trigger log not found: %s", self.trigger_log_id)
                    return

                # Calculate elapsed time
                elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()

                # Extract relevant data from the result
                if not self.graph_runtime_state:
                    logger.error("Graph runtime state is not set")
                    return

                outputs = self.graph_runtime_state.outputs

                # Basically, workflow_execution_id is the same as workflow_run_id
                workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id
                assert workflow_run_id, "Workflow run id is not set"

                total_tokens = self.graph_runtime_state.total_tokens

                # Update the trigger log with the final status
                trigger_log.status = self._STATUS_MAP[type(event)]
                trigger_log.workflow_run_id = workflow_run_id
                trigger_log.outputs = TypeAdapter(dict[str, Any]).dump_json(outputs).decode()

                if trigger_log.elapsed_time is None:
                    trigger_log.elapsed_time = elapsed_time
                else:
                    trigger_log.elapsed_time += elapsed_time

                trigger_log.total_tokens = total_tokens
                trigger_log.finished_at = datetime.now(UTC)
                repo.update(trigger_log)
                session.commit()

    def on_graph_end(self, error: Exception | None) -> None:
        pass
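For illustration, constructing the layer might look like the sketch below; the plan entity, trigger-log id, and engine variable are all made up, not taken from this diff:

# Hypothetical construction; `plan_entity` is an AsyncWorkflowCFSPlanEntity built elsewhere.
from datetime import UTC, datetime
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db

layer = TriggerPostLayer(
    cfs_plan_scheduler_entity=plan_entity,
    start_time=datetime.now(UTC),
    trigger_log_id="b9f3f5c0-0000-0000-0000-000000000000",  # hypothetical id
    session_maker=sessionmaker(bind=db.engine),
)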
@@ -121,7 +121,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
        if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
            # start generate conversation name thread
            self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
-               conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
+               conversation_id=self._conversation_id, query=self._application_generate_entity.query
            )

        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@@ -140,7 +140,27 @@ class MessageCycleManager:
        if not self._application_generate_entity.app_config.additional_features:
            raise ValueError("Additional features not found")
        if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
-           self._task_state.metadata.retriever_resources = event.retriever_resources
+           merged_resources = [r for r in self._task_state.metadata.retriever_resources or [] if r]
+           existing_ids = {(r.dataset_id, r.document_id) for r in merged_resources if r.dataset_id and r.document_id}
+
+           # Add new unique resources from the event
+           for resource in event.retriever_resources or []:
+               if not resource:
+                   continue
+
+               is_duplicate = (
+                   resource.dataset_id
+                   and resource.document_id
+                   and (resource.dataset_id, resource.document_id) in existing_ids
+               )
+
+               if not is_duplicate:
+                   merged_resources.append(resource)
+
+           # Renumber positions after merging
+           for i, resource in enumerate(merged_resources, 1):
+               resource.position = i
+
+           self._task_state.metadata.retriever_resources = merged_resources

    def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> MessageFileStreamResponse | None:
        """
api/core/entities/document_task.py (new file, +15)
@@ -0,0 +1,15 @@
from collections.abc import Sequence
from dataclasses import dataclass


@dataclass
class DocumentTask:
    """Document task entity for document indexing operations.

    This class represents a document indexing task that can be queued
    and processed by the document indexing system.
    """

    tenant_id: str
    dataset_id: str
    document_ids: Sequence[str]
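For illustration, instantiating the dataclass is straightforward (all IDs below are made up):

task = DocumentTask(
    tenant_id="tenant-123",
    dataset_id="dataset-456",
    document_ids=["doc-1", "doc-2"],
)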
api/core/entities/mcp_provider.py (new file, +329)
@@ -0,0 +1,329 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

from pydantic import BaseModel

from configs import dify_config
from core.entities.provider_entities import BasicProviderConfig
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType

if TYPE_CHECKING:
    from models.tools import MCPToolProvider

# Constants
CLIENT_NAME = "Dify"
CLIENT_URI = "https://github.com/langgenius/dify"
DEFAULT_TOKEN_TYPE = "Bearer"
DEFAULT_EXPIRES_IN = 3600
MASK_CHAR = "*"
MIN_UNMASK_LENGTH = 6


class MCPSupportGrantType(StrEnum):
    """The supported grant types for MCP"""

    AUTHORIZATION_CODE = "authorization_code"
    CLIENT_CREDENTIALS = "client_credentials"
    REFRESH_TOKEN = "refresh_token"


class MCPAuthentication(BaseModel):
    client_id: str
    client_secret: str | None = None


class MCPConfiguration(BaseModel):
    timeout: float = 30
    sse_read_timeout: float = 300


class MCPProviderEntity(BaseModel):
    """MCP Provider domain entity for business logic operations"""

    # Basic identification
    id: str
    provider_id: str  # server_identifier
    name: str
    tenant_id: str
    user_id: str

    # Server connection info
    server_url: str  # encrypted URL
    headers: dict[str, str]  # encrypted headers
    timeout: float
    sse_read_timeout: float

    # Authentication related
    authed: bool
    credentials: dict[str, Any]  # encrypted credentials
    code_verifier: str | None = None  # for OAuth

    # Tools and display info
    tools: list[dict[str, Any]]  # parsed tools list
    icon: str | dict[str, str]  # parsed icon

    # Timestamps
    created_at: datetime
    updated_at: datetime

    @classmethod
    def from_db_model(cls, db_provider: "MCPToolProvider") -> "MCPProviderEntity":
        """Create entity from database model with decryption"""
        return cls(
            id=db_provider.id,
            provider_id=db_provider.server_identifier,
            name=db_provider.name,
            tenant_id=db_provider.tenant_id,
            user_id=db_provider.user_id,
            server_url=db_provider.server_url,
            headers=db_provider.headers,
            timeout=db_provider.timeout,
            sse_read_timeout=db_provider.sse_read_timeout,
            authed=db_provider.authed,
            credentials=db_provider.credentials,
            tools=db_provider.tool_dict,
            icon=db_provider.icon or "",
            created_at=db_provider.created_at,
            updated_at=db_provider.updated_at,
        )
    @property
    def redirect_url(self) -> str:
        """OAuth redirect URL"""
        return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"

    @property
    def client_metadata(self) -> OAuthClientMetadata:
        """Metadata about this OAuth client."""
        # Get grant type from credentials
        credentials = self.decrypt_credentials()

        # Try to get grant_type from different locations
        grant_type = credentials.get("grant_type", MCPSupportGrantType.AUTHORIZATION_CODE)

        # For nested structures, check if client_information has grant_types
        if "client_information" in credentials and isinstance(credentials["client_information"], dict):
            client_info = credentials["client_information"]
            # If grant_types is specified in client_information, use it to determine grant_type
            if "grant_types" in client_info and isinstance(client_info["grant_types"], list):
                if "client_credentials" in client_info["grant_types"]:
                    grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS
                elif "authorization_code" in client_info["grant_types"]:
                    grant_type = MCPSupportGrantType.AUTHORIZATION_CODE

        # Configure based on grant type
        is_client_credentials = grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS

        grant_types = ["refresh_token"]
        grant_types.append("client_credentials" if is_client_credentials else "authorization_code")

        response_types = [] if is_client_credentials else ["code"]
        redirect_uris = [] if is_client_credentials else [self.redirect_url]

        return OAuthClientMetadata(
            redirect_uris=redirect_uris,
            token_endpoint_auth_method="none",
            grant_types=grant_types,
            response_types=response_types,
            client_name=CLIENT_NAME,
            client_uri=CLIENT_URI,
        )

    @property
    def provider_icon(self) -> dict[str, str] | str:
        """Get provider icon, handling both dict and string formats"""
        if isinstance(self.icon, dict):
            return self.icon
        try:
            return json.loads(self.icon)
        except (json.JSONDecodeError, TypeError):
            # If not JSON, assume it's a file path
            return file_helpers.get_signed_file_url(self.icon)

    def to_api_response(self, user_name: str | None = None, include_sensitive: bool = True) -> dict[str, Any]:
        """Convert to API response format

        Args:
            user_name: User name to display
            include_sensitive: If False, skip expensive decryption operations (for list view optimization)
        """
        response = {
            "id": self.id,
            "author": user_name or "Anonymous",
            "name": self.name,
            "icon": self.provider_icon,
            "type": ToolProviderType.MCP.value,
            "is_team_authorization": self.authed,
            "server_url": self.masked_server_url(),
            "server_identifier": self.provider_id,
            "updated_at": int(self.updated_at.timestamp()),
            "label": I18nObject(en_US=self.name, zh_Hans=self.name).to_dict(),
            "description": I18nObject(en_US="", zh_Hans="").to_dict(),
        }

        # Add configuration
        response["configuration"] = {
            "timeout": str(self.timeout),
            "sse_read_timeout": str(self.sse_read_timeout),
        }

        # Skip expensive operations when sensitive data is not needed (e.g., list view)
        if not include_sensitive:
            response["masked_headers"] = {}
            response["is_dynamic_registration"] = True
        else:
            # Add masked headers
            response["masked_headers"] = self.masked_headers()

            # Add authentication info if available
            masked_creds = self.masked_credentials()
            if masked_creds:
                response["authentication"] = masked_creds
            response["is_dynamic_registration"] = self.credentials.get("client_information", {}).get(
                "is_dynamic_registration", True
            )

        return response
    def retrieve_client_information(self) -> OAuthClientInformation | None:
        """OAuth client information if available"""
        credentials = self.decrypt_credentials()
        if not credentials:
            return None

        # Check if we have a nested client_information structure
        if "client_information" not in credentials:
            return None
        client_info_data = credentials["client_information"]
        if isinstance(client_info_data, dict):
            if "encrypted_client_secret" in client_info_data:
                client_info_data["client_secret"] = encrypter.decrypt_token(
                    self.tenant_id, client_info_data["encrypted_client_secret"]
                )
            return OAuthClientInformation.model_validate(client_info_data)
        return None

    def retrieve_tokens(self) -> OAuthTokens | None:
        """OAuth tokens if available"""
        if not self.credentials:
            return None
        credentials = self.decrypt_credentials()
        return OAuthTokens(
            access_token=credentials.get("access_token", ""),
            token_type=credentials.get("token_type", DEFAULT_TOKEN_TYPE),
            expires_in=int(credentials.get("expires_in", str(DEFAULT_EXPIRES_IN)) or DEFAULT_EXPIRES_IN),
            refresh_token=credentials.get("refresh_token", ""),
        )

    def masked_server_url(self) -> str:
        """Masked server URL for display"""
        parsed = urlparse(self.decrypt_server_url())
        if parsed.path and parsed.path != "/":
            masked = parsed._replace(path="/******")
            return masked.geturl()
        return parsed.geturl()

    def _mask_value(self, value: str) -> str:
        """Mask a sensitive value for display"""
        if len(value) > MIN_UNMASK_LENGTH:
            return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
        else:
            return MASK_CHAR * len(value)

    def masked_headers(self) -> dict[str, str]:
        """Masked headers for display"""
        return {key: self._mask_value(value) for key, value in self.decrypt_headers().items()}

    def masked_credentials(self) -> dict[str, str]:
        """Masked credentials for display"""
        credentials = self.decrypt_credentials()
        if not credentials:
            return {}

        masked = {}

        if "client_information" not in credentials or not isinstance(credentials["client_information"], dict):
            return {}
        client_info = credentials["client_information"]
        # Mask sensitive fields from the nested structure
        if client_info.get("client_id"):
            masked["client_id"] = self._mask_value(client_info["client_id"])
        if client_info.get("encrypted_client_secret"):
            masked["client_secret"] = self._mask_value(
                encrypter.decrypt_token(self.tenant_id, client_info["encrypted_client_secret"])
            )
        if client_info.get("client_secret"):
            masked["client_secret"] = self._mask_value(client_info["client_secret"])
        return masked

    def decrypt_server_url(self) -> str:
        """Decrypt server URL"""
        return encrypter.decrypt_token(self.tenant_id, self.server_url)

    def _decrypt_dict(self, data: dict[str, Any]) -> dict[str, Any]:
        """Generic method to decrypt dictionary fields"""
        from core.tools.utils.encryption import create_provider_encrypter

        if not data:
            return {}

        # Only decrypt fields that are actually encrypted.
        # For nested structures, client_information is not encrypted as a whole.
        encrypted_fields = []
        for key, value in data.items():
            # Skip nested objects - they are not encrypted
            if isinstance(value, dict):
                continue
            # Only process string values that might be encrypted
            if isinstance(value, str) and value:
                encrypted_fields.append(key)

        if not encrypted_fields:
            return data

        # Create a dynamic config only for the encrypted fields
        config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in encrypted_fields]

        encrypter_instance, _ = create_provider_encrypter(
            tenant_id=self.tenant_id,
            config=config,
            cache=NoOpProviderCredentialCache(),
        )

        # Decrypt only the encrypted fields
        decrypted_data = encrypter_instance.decrypt({k: data[k] for k in encrypted_fields})

        # Merge decrypted data with the original data (preserving non-encrypted fields)
        result = data.copy()
        result.update(decrypted_data)

        return result

    def decrypt_headers(self) -> dict[str, Any]:
        """Decrypt headers"""
        return self._decrypt_dict(self.headers)

    def decrypt_credentials(self) -> dict[str, Any]:
        """Decrypt credentials"""
        return self._decrypt_dict(self.credentials)

    def decrypt_authentication(self) -> dict[str, Any]:
        """Decrypt authentication"""
        # Option 1: if headers are provided, use them and skip token retrieval
        headers = self.decrypt_headers()

        # Option 2: add the OAuth token if authed and no headers were provided
        if not self.headers and self.authed:
            token = self.retrieve_tokens()
            if token:
                headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"
        return headers
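As a self-contained sketch of the masking rule above: values longer than MIN_UNMASK_LENGTH keep only their first and last two characters. The sample values are made up:

MASK_CHAR = "*"
MIN_UNMASK_LENGTH = 6

def mask_value(value: str) -> str:
    # Long enough: keep two characters on each end, mask the middle
    if len(value) > MIN_UNMASK_LENGTH:
        return value[:2] + MASK_CHAR * (len(value) - 4) + value[-2:]
    # Too short to partially reveal: mask everything
    return MASK_CHAR * len(value)

print(mask_value("secret-token"))  # se********en
print(mask_value("abc"))           # ***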
@@ -14,6 +14,7 @@ class CommonParameterType(StrEnum):
    APP_SELECTOR = "app-selector"
    MODEL_SELECTOR = "model-selector"
    TOOLS_SELECTOR = "array[tools]"
    CHECKBOX = "checkbox"
+   ANY = auto()

    # Dynamic select parameter
@@ -1533,6 +1533,9 @@ class ProviderConfiguration(BaseModel):
            # Return composite sort key: (model_type value, model position index)
            return (model.model_type.value, position_index)

+       # Deduplicate by (model, model_type, fetch_from), keeping the last occurrence
+       provider_models = list({(m.model, m.model_type, m.fetch_from): m for m in provider_models}.values())
+
        # Sort using the composite sort key
        return sorted(provider_models, key=get_sort_key)
@@ -107,7 +107,7 @@ class CustomModelConfiguration(BaseModel):

    model: str
    model_type: ModelType
-   credentials: dict | None = None
+   credentials: dict | None
    current_credential_id: str | None = None
    current_credential_name: str | None = None
    available_model_credentials: list[CredentialConfiguration] = []
@@ -207,6 +207,7 @@ class ProviderConfig(BasicProviderConfig):
    required: bool = False
    default: Union[int, str, float, bool] | None = None
    options: list[Option] | None = None
+   multiple: bool | None = False
    label: I18nObject | None = None
    help: I18nObject | None = None
    url: str | None = None
@@ -74,6 +74,10 @@ class File(BaseModel):
        storage_key: str | None = None,
        dify_model_identity: str | None = FILE_MODEL_IDENTITY,
        url: str | None = None,
+       # Legacy compatibility fields - explicitly handle known extra fields
+       tool_file_id: str | None = None,
+       upload_file_id: str | None = None,
+       datasource_file_id: str | None = None,
    ):
        super().__init__(
            id=id,
@@ -6,10 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
    @classmethod
    def get_runner_script(cls) -> str:
-       runner_script = dedent(
-           f"""
-           // declare main function
-           {cls._code_placeholder}
+       runner_script = dedent(f""" {cls._code_placeholder}

            // decode and prepare input object
            var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
@@ -21,6 +18,5 @@ class NodeJsTemplateTransformer(TemplateTransformer):
            var output_json = JSON.stringify(output_obj)
            var result = `<<RESULT>>${{output_json}}<<RESULT>>`
            console.log(result)
-           """
-       )
+       """)
        return runner_script
@@ -6,9 +6,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
    @classmethod
    def get_runner_script(cls) -> str:
-       runner_script = dedent(f"""
-           # declare main function
-           {cls._code_placeholder}
+       runner_script = dedent(f""" {cls._code_placeholder}

            import json
            from base64 import b64decode
@@ -29,6 +29,18 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
    return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]


+def batch_fetch_plugin_by_ids(plugin_ids: list[str]) -> list[dict]:
+    if not plugin_ids:
+        return []
+
+    url = str(marketplace_api_url / "api/v1/plugins/batch")
+    response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
+    response.raise_for_status()
+
+    data = response.json()
+    return data.get("data", {}).get("plugins", [])
+
+
def batch_fetch_plugin_manifests_ignore_deserialization_error(
    plugin_ids: list[str],
) -> Sequence[MarketplacePluginDeclaration]:
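A hypothetical call of the new helper; the plugin IDs and the keys read from the returned dicts are made up and depend on the marketplace's response schema:

plugins = batch_fetch_plugin_by_ids(["langgenius/openai", "langgenius/tavily"])
for plugin in plugins:
    print(plugin.get("name"), plugin.get("latest_version"))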
@@ -3,7 +3,7 @@ import re
from collections.abc import Sequence
from typing import Any

-from core.tools.entities.tool_entities import CredentialType
+from core.plugin.entities.plugin_daemon import CredentialType

logger = logging.getLogger(__name__)
api/core/helper/provider_encryption.py (new file, +129)
@@ -0,0 +1,129 @@
import contextlib
from collections.abc import Mapping
from copy import deepcopy
from typing import Any, Protocol

from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter


class ProviderConfigCache(Protocol):
    """
    Interface for provider configuration cache operations
    """

    def get(self) -> dict[str, Any] | None:
        """Get cached provider configuration"""
        ...

    def set(self, config: dict[str, Any]) -> None:
        """Cache provider configuration"""
        ...

    def delete(self) -> None:
        """Delete cached provider configuration"""
        ...


class ProviderConfigEncrypter:
    tenant_id: str
    config: list[BasicProviderConfig]
    provider_config_cache: ProviderConfigCache

    def __init__(
        self,
        tenant_id: str,
        config: list[BasicProviderConfig],
        provider_config_cache: ProviderConfigCache,
    ):
        self.tenant_id = tenant_id
        self.config = config
        self.provider_config_cache = provider_config_cache

    def _deep_copy(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Deep copy data.
        """
        return deepcopy(data)

    def encrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Encrypt tool credentials with the tenant id.

        Returns a deep copy of the credentials with encrypted values.
        """
        data = dict(self._deep_copy(data))

        # collect the fields that need to be encrypted
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
                    data[field_name] = encrypted

        return data

    def mask_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Mask credentials.

        Returns a deep copy of the credentials with masked values.
        """
        data = dict(self._deep_copy(data))

        # collect the fields that need to be masked
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    if len(data[field_name]) > 6:
                        data[field_name] = (
                            data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
                        )
                    else:
                        data[field_name] = "*" * len(data[field_name])

        return data

    def mask_plugin_credentials(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
        return self.mask_credentials(data)

    def decrypt(self, data: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Decrypt tool credentials with the tenant id.

        Returns a deep copy of the credentials with decrypted values.
        """
        cached_credentials = self.provider_config_cache.get()
        if cached_credentials:
            return cached_credentials

        data = dict(self._deep_copy(data))
        # collect the fields that need to be decrypted
        fields = dict[str, BasicProviderConfig]()
        for credential in self.config:
            fields[credential.name] = credential

        for field_name, field in fields.items():
            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                if field_name in data:
                    with contextlib.suppress(Exception):
                        # if the value is None or an empty string, skip decryption
                        if not data[field_name]:
                            continue

                        data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])

        self.provider_config_cache.set(dict(data))
        return data


def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
    return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
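A minimal sketch of the factory above, assuming a toy in-memory cache that satisfies the ProviderConfigCache protocol; the tenant id and credential values are made up:

from typing import Any

from core.entities.provider_entities import BasicProviderConfig

class InMemoryCache:
    """Toy cache satisfying ProviderConfigCache (illustrative only)."""

    def __init__(self) -> None:
        self._store: dict[str, Any] | None = None

    def get(self) -> dict[str, Any] | None:
        return self._store

    def set(self, config: dict[str, Any]) -> None:
        self._store = config

    def delete(self) -> None:
        self._store = None

config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name="api_key")]
enc, cache = create_provider_encrypter(tenant_id="tenant-123", config=config, cache=InMemoryCache())
print(enc.mask_credentials({"api_key": "sk-1234567890"}))  # {'api_key': 'sk*********90'}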
@@ -6,11 +6,15 @@ import secrets
import urllib.parse
from urllib.parse import urljoin, urlparse

import httpx
-from pydantic import BaseModel, ValidationError
+from httpx import ConnectError, HTTPStatusError, RequestError
+from pydantic import ValidationError

-from core.mcp.auth.auth_provider import OAuthClientProvider
+from core.entities.mcp_provider import MCPProviderEntity, MCPSupportGrantType
+from core.helper import ssrf_proxy
+from core.mcp.entities import AuthAction, AuthActionType, AuthResult, OAuthCallbackState
+from core.mcp.error import MCPRefreshTokenError
from core.mcp.types import (
+   LATEST_PROTOCOL_VERSION,
    OAuthClientInformation,
    OAuthClientInformationFull,
    OAuthClientMetadata,
@@ -19,21 +23,10 @@ from core.mcp.types import (
)
from extensions.ext_redis import redis_client

-LATEST_PROTOCOL_VERSION = "1.0"
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60  # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"


-class OAuthCallbackState(BaseModel):
-    provider_id: str
-    tenant_id: str
-    server_url: str
-    metadata: OAuthMetadata | None = None
-    client_information: OAuthClientInformation
-    code_verifier: str
-    redirect_uri: str


def generate_pkce_challenge() -> tuple[str, str]:
    """Generate PKCE challenge and verifier."""
    code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
@@ -80,8 +73,13 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
    raise ValueError(f"Invalid state parameter: {str(e)}")


-def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
-    """Handle the callback from the OAuth provider."""
+def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
+    """
+    Handle the callback from the OAuth provider.
+
+    Returns:
+        A tuple of (callback_state, tokens) that can be used by the caller to save data.
+    """
    # Retrieve state data from Redis (state is automatically deleted after retrieval)
    full_state_data = _retrieve_redis_state(state_key)

@@ -93,30 +91,32 @@ def handle_callback(state_key: str, authorization_code: str) -> tuple[OAuthCallbackState, OAuthTokens]:
        full_state_data.code_verifier,
        full_state_data.redirect_uri,
    )
-   provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
-   provider.save_tokens(tokens)
-   return full_state_data
+
+   return full_state_data, tokens


def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
    """Check if the server supports OAuth 2.0 Resource Discovery."""
-   b_scheme, b_netloc, b_path, _, b_query, b_fragment = urlparse(server_url, "", True)
-   url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource{b_path}"
+   b_scheme, b_netloc, _, _, b_query, b_fragment = urlparse(server_url, "", True)
+   url_for_resource_discovery = f"{b_scheme}://{b_netloc}/.well-known/oauth-protected-resource"
    if b_query:
        url_for_resource_discovery += f"?{b_query}"
    if b_fragment:
        url_for_resource_discovery += f"#{b_fragment}"
    try:
        headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"}
-       response = httpx.get(url_for_resource_discovery, headers=headers)
+       response = ssrf_proxy.get(url_for_resource_discovery, headers=headers)
        if 200 <= response.status_code < 300:
            body = response.json()
-           if "authorization_server_url" in body:
+           # Support both singular and plural forms
+           if body.get("authorization_servers"):
+               return True, body["authorization_servers"][0]
+           elif body.get("authorization_server_url"):
                return True, body["authorization_server_url"][0]
            else:
                return False, ""
        return False, ""
-   except httpx.RequestError:
+   except RequestError:
        # The server does not support resource discovery; fall back to well-known OAuth metadata
        return False, ""
@@ -126,27 +126,37 @@ def discover_oauth_metadata(server_url: str, protocol_version: str | None = None):
    # First check if the server supports OAuth 2.0 Resource Discovery
    support_resource_discovery, oauth_discovery_url = check_support_resource_discovery(server_url)
    if support_resource_discovery:
-       url = oauth_discovery_url
+       # The oauth_discovery_url is the authorization server base URL.
+       # Try OAuth 2.0 authorization server metadata first, then OpenID Connect discovery.
+       urls_to_try = [
+           urljoin(oauth_discovery_url + "/", ".well-known/oauth-authorization-server"),
+           urljoin(oauth_discovery_url + "/", ".well-known/openid-configuration"),
+       ]
    else:
-       url = urljoin(server_url, "/.well-known/oauth-authorization-server")
+       urls_to_try = [urljoin(server_url, "/.well-known/oauth-authorization-server")]

-   try:
-       headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
-       response = httpx.get(url, headers=headers)
-       if response.status_code == 404:
-           return None
-       if not response.is_success:
-           raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
-       return OAuthMetadata.model_validate(response.json())
-   except httpx.RequestError as e:
-       if isinstance(e, httpx.ConnectError):
-           response = httpx.get(url)
-           if response.status_code == 404:
-               return None
-           if not response.is_success:
-               raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
-           return OAuthMetadata.model_validate(response.json())
-       raise
+   headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
+
+   for url in urls_to_try:
+       try:
+           response = ssrf_proxy.get(url, headers=headers)
+           if response.status_code == 404:
+               continue
+           response.raise_for_status()
+           return OAuthMetadata.model_validate(response.json())
+       except (RequestError, HTTPStatusError) as e:
+           if isinstance(e, ConnectError):
+               response = ssrf_proxy.get(url)
+               if response.status_code == 404:
+                   continue  # Try next URL
+               if not response.is_success:
+                   raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
+               return OAuthMetadata.model_validate(response.json())
+           # For other errors, try the next URL
+           continue
+
+   return None  # No metadata found


def start_authorization(
@@ -213,7 +223,7 @@
    redirect_uri: str,
) -> OAuthTokens:
    """Exchanges an authorization code for an access token."""
-   grant_type = "authorization_code"
+   grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value

    if metadata:
        token_url = metadata.token_endpoint
@@ -233,7 +243,7 @@
    if client_information.client_secret:
        params["client_secret"] = client_information.client_secret

-   response = httpx.post(token_url, data=params)
+   response = ssrf_proxy.post(token_url, data=params)
    if not response.is_success:
        raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
    return OAuthTokens.model_validate(response.json())
@@ -246,7 +256,7 @@
    refresh_token: str,
) -> OAuthTokens:
    """Exchange a refresh token for an updated access token."""
-   grant_type = "refresh_token"
+   grant_type = MCPSupportGrantType.REFRESH_TOKEN.value

    if metadata:
        token_url = metadata.token_endpoint
@@ -263,10 +273,55 @@

    if client_information.client_secret:
        params["client_secret"] = client_information.client_secret

-   response = httpx.post(token_url, data=params)
+   try:
+       response = ssrf_proxy.post(token_url, data=params)
+   except ssrf_proxy.MaxRetriesExceededError as e:
+       raise MCPRefreshTokenError(e) from e
    if not response.is_success:
-       raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
+       raise MCPRefreshTokenError(response.text)
    return OAuthTokens.model_validate(response.json())


+def client_credentials_flow(
+    server_url: str,
+    metadata: OAuthMetadata | None,
+    client_information: OAuthClientInformation,
+    scope: str | None = None,
+) -> OAuthTokens:
+    """Execute the Client Credentials Flow to get an access token."""
+    grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
+
+    if metadata:
+        token_url = metadata.token_endpoint
+        if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
+            raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
+    else:
+        token_url = urljoin(server_url, "/token")
+
+    # Support both Basic Auth and body parameters for client authentication
+    headers = {"Content-Type": "application/x-www-form-urlencoded"}
+    data = {"grant_type": grant_type}
+
+    if scope:
+        data["scope"] = scope
+
+    # If client_secret is provided, use Basic Auth (the preferred method)
+    if client_information.client_secret:
+        credentials = f"{client_information.client_id}:{client_information.client_secret}"
+        encoded_credentials = base64.b64encode(credentials.encode()).decode()
+        headers["Authorization"] = f"Basic {encoded_credentials}"
+    else:
+        # Fall back to including the credentials in the body
+        data["client_id"] = client_information.client_id
+        if client_information.client_secret:
+            data["client_secret"] = client_information.client_secret
+
+    response = ssrf_proxy.post(token_url, headers=headers, data=data)
+    if not response.is_success:
+        raise ValueError(
+            f"Client credentials token request failed: HTTP {response.status_code}, Response: {response.text}"
+        )
+
+    return OAuthTokens.model_validate(response.json())
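For illustration, invoking the new flow directly might look like this, assuming OAuthClientInformation accepts client_id and client_secret; the URL, credentials, and scope are made up:

info = OAuthClientInformation(client_id="my-client-id", client_secret="my-client-secret")
tokens = client_credentials_flow(
    server_url="https://auth.example.com",
    metadata=None,  # falls back to urljoin(server_url, "/token")
    client_information=info,
    scope="tools:read",
)
print(tokens.access_token)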
@@ -283,7 +338,7 @@
    else:
        registration_url = urljoin(server_url, "/register")

-   response = httpx.post(
+   response = ssrf_proxy.post(
        registration_url,
        json=client_metadata.model_dump(),
        headers={"Content-Type": "application/json"},
@@ -294,28 +349,111 @@


def auth(
-   provider: OAuthClientProvider,
-   server_url: str,
+   provider: MCPProviderEntity,
    authorization_code: str | None = None,
    state_param: str | None = None,
-   for_list: bool = False,
-) -> dict[str, str]:
-   """Orchestrates the full auth flow with a server using secure Redis state storage."""
-   metadata = discover_oauth_metadata(server_url)
+) -> AuthResult:
+   """
+   Orchestrates the full auth flow with a server using secure Redis state storage.
+
+   This function performs only network operations and returns actions that need
+   to be performed by the caller (such as saving data to the database).
+
+   Args:
+       provider: The MCP provider entity
+       authorization_code: Optional authorization code from the OAuth callback
+       state_param: Optional state parameter from the OAuth callback
+
+   Returns:
+       AuthResult containing the actions to be performed and the response data
+   """
+   actions: list[AuthAction] = []
+   server_url = provider.decrypt_server_url()
+   server_metadata = discover_oauth_metadata(server_url)
+   client_metadata = provider.client_metadata
+   provider_id = provider.id
+   tenant_id = provider.tenant_id
+   client_information = provider.retrieve_client_information()
+   redirect_url = provider.redirect_url
+
+   # Determine the grant type based on server metadata
+   if not server_metadata:
+       raise ValueError("Failed to discover OAuth metadata from server")
+
+   supported_grant_types = server_metadata.grant_types_supported or []
+
+   # Convert to lowercase for comparison
+   supported_grant_types_lower = [gt.lower() for gt in supported_grant_types]
+
+   # Determine which grant type to use
+   effective_grant_type = None
+   if MCPSupportGrantType.AUTHORIZATION_CODE.value in supported_grant_types_lower:
+       effective_grant_type = MCPSupportGrantType.AUTHORIZATION_CODE.value
+   else:
+       effective_grant_type = MCPSupportGrantType.CLIENT_CREDENTIALS.value
+
+   # Get the stored credentials
+   credentials = provider.decrypt_credentials()

    # Handle client registration if needed
-   client_information = provider.client_information()
    if not client_information:
        if authorization_code is not None:
            raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
+
+       # For the client credentials flow, there is no need to register the client dynamically
+       if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
+           # The client should provide client_id and client_secret directly
+           raise ValueError("Client credentials flow requires client_id and client_secret to be provided")
+
        try:
-           full_information = register_client(server_url, metadata, provider.client_metadata)
-       except httpx.RequestError as e:
+           full_information = register_client(server_url, server_metadata, client_metadata)
+       except RequestError as e:
            raise ValueError(f"Could not register OAuth client: {e}")
-       provider.save_client_information(full_information)
+
+       # Return an action to save the client information
+       actions.append(
+           AuthAction(
+               action_type=AuthActionType.SAVE_CLIENT_INFO,
+               data={"client_information": full_information.model_dump()},
+               provider_id=provider_id,
+               tenant_id=tenant_id,
+           )
+       )
+
        client_information = full_information

-   # Exchange authorization code for tokens
+   # Handle the client credentials flow
+   if effective_grant_type == MCPSupportGrantType.CLIENT_CREDENTIALS.value:
+       # Direct token request without user interaction
+       try:
+           scope = credentials.get("scope")
+           tokens = client_credentials_flow(
+               server_url,
+               server_metadata,
+               client_information,
+               scope,
+           )
+
+           # Return an action to save the tokens and grant type
+           token_data = tokens.model_dump()
+           token_data["grant_type"] = MCPSupportGrantType.CLIENT_CREDENTIALS.value
+
+           actions.append(
+               AuthAction(
+                   action_type=AuthActionType.SAVE_TOKENS,
+                   data=token_data,
+                   provider_id=provider_id,
+                   tenant_id=tenant_id,
+               )
+           )
+
+           return AuthResult(actions=actions, response={"result": "success"})
+       except (RequestError, ValueError, KeyError) as e:
+           # RequestError: HTTP request failed
+           # ValueError: invalid response data
+           # KeyError: missing required fields in the response
+           raise ValueError(f"Client credentials flow failed: {e}")
+
+   # Exchange the authorization code for tokens (authorization code flow)
    if authorization_code is not None:
        if not state_param:
            raise ValueError("State parameter is required when exchanging authorization code")
@@ -335,35 +473,69 @@

        tokens = exchange_authorization(
            server_url,
-           metadata,
+           server_metadata,
            client_information,
            authorization_code,
            code_verifier,
            redirect_uri,
        )
-       provider.save_tokens(tokens)
-       return {"result": "success"}

-   provider_tokens = provider.tokens()
+       # Return an action to save the tokens
+       actions.append(
+           AuthAction(
+               action_type=AuthActionType.SAVE_TOKENS,
+               data=tokens.model_dump(),
+               provider_id=provider_id,
+               tenant_id=tenant_id,
+           )
+       )
+
+       return AuthResult(actions=actions, response={"result": "success"})
+
+   provider_tokens = provider.retrieve_tokens()

    # Handle token refresh or new authorization
    if provider_tokens and provider_tokens.refresh_token:
        try:
-           new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
-           provider.save_tokens(new_tokens)
-           return {"result": "success"}
-       except Exception as e:
+           new_tokens = refresh_authorization(
+               server_url, server_metadata, client_information, provider_tokens.refresh_token
+           )
+
+           # Return an action to save the new tokens
+           actions.append(
+               AuthAction(
+                   action_type=AuthActionType.SAVE_TOKENS,
+                   data=new_tokens.model_dump(),
+                   provider_id=provider_id,
+                   tenant_id=tenant_id,
+               )
+           )
+
+           return AuthResult(actions=actions, response={"result": "success"})
+       except (RequestError, ValueError, KeyError) as e:
+           # RequestError: HTTP request failed
+           # ValueError: invalid response data
+           # KeyError: missing required fields in the response
            raise ValueError(f"Could not refresh OAuth tokens: {e}")

-   # Start new authorization flow
+   # Start a new authorization flow (only for the authorization code flow)
    authorization_url, code_verifier = start_authorization(
        server_url,
-       metadata,
+       server_metadata,
        client_information,
-       provider.redirect_url,
-       provider.mcp_provider.id,
-       provider.mcp_provider.tenant_id,
+       redirect_url,
+       provider_id,
+       tenant_id,
    )

-   provider.save_code_verifier(code_verifier)
-   return {"authorization_url": authorization_url}
+   # Return an action to save the code verifier
+   actions.append(
+       AuthAction(
+           action_type=AuthActionType.SAVE_CODE_VERIFIER,
+           data={"code_verifier": code_verifier},
+           provider_id=provider_id,
+           tenant_id=tenant_id,
+       )
+   )
+
+   return AuthResult(actions=actions, response={"authorization_url": authorization_url})
api/core/mcp/auth/auth_provider.py (deleted file)
@@ -1,77 +0,0 @@
-from configs import dify_config
-from core.mcp.types import (
-    OAuthClientInformation,
-    OAuthClientInformationFull,
-    OAuthClientMetadata,
-    OAuthTokens,
-)
-from models.tools import MCPToolProvider
-from services.tools.mcp_tools_manage_service import MCPToolManageService
-
-
-class OAuthClientProvider:
-    mcp_provider: MCPToolProvider
-
-    def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
-        if for_list:
-            self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
-        else:
-            self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
-
-    @property
-    def redirect_url(self) -> str:
-        """The URL to redirect the user agent to after authorization."""
-        return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
-
-    @property
-    def client_metadata(self) -> OAuthClientMetadata:
-        """Metadata about this OAuth client."""
-        return OAuthClientMetadata(
-            redirect_uris=[self.redirect_url],
-            token_endpoint_auth_method="none",
-            grant_types=["authorization_code", "refresh_token"],
-            response_types=["code"],
-            client_name="Dify",
-            client_uri="https://github.com/langgenius/dify",
-        )
-
-    def client_information(self) -> OAuthClientInformation | None:
-        """Loads information about this OAuth client."""
-        client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
-        if not client_information:
-            return None
-        return OAuthClientInformation.model_validate(client_information)
-
-    def save_client_information(self, client_information: OAuthClientInformationFull):
-        """Saves client information after dynamic registration."""
-        MCPToolManageService.update_mcp_provider_credentials(
-            self.mcp_provider,
-            {"client_information": client_information.model_dump()},
-        )
-
-    def tokens(self) -> OAuthTokens | None:
-        """Loads any existing OAuth tokens for the current session."""
-        credentials = self.mcp_provider.decrypted_credentials
-        if not credentials:
-            return None
-        return OAuthTokens(
-            access_token=credentials.get("access_token", ""),
-            token_type=credentials.get("token_type", "Bearer"),
-            expires_in=int(credentials.get("expires_in", "3600") or 3600),
-            refresh_token=credentials.get("refresh_token", ""),
-        )
-
-    def save_tokens(self, tokens: OAuthTokens):
-        """Stores new OAuth tokens for the current session."""
-        # update mcp provider credentials
-        token_dict = tokens.model_dump()
-        MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
-
-    def save_code_verifier(self, code_verifier: str):
-        """Saves a PKCE code verifier for the current session."""
-        MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
-
-    def code_verifier(self) -> str:
-        """Loads the PKCE code verifier for the current session."""
-        # get the code verifier from mcp provider credentials
-        return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
api/core/mcp/auth_client.py (new file, +191)
@@ -0,0 +1,191 @@
"""
MCP Client with Authentication Retry Support

This module provides an enhanced MCPClient that automatically handles
authentication failures and retries operations after refreshing tokens.
"""

import logging
from collections.abc import Callable
from typing import Any

from sqlalchemy.orm import Session

from core.entities.mcp_provider import MCPProviderEntity
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import CallToolResult, Tool
from extensions.ext_database import db

logger = logging.getLogger(__name__)


class MCPClientWithAuthRetry(MCPClient):
    """
    An enhanced MCPClient that provides automatic authentication retry.

    This class extends MCPClient and intercepts MCPAuthError exceptions
    to refresh authentication before retrying failed operations.

    Note: This class uses lazy session creation - database sessions are only
    created when an authentication retry is actually needed, not on every request.
    """

    def __init__(
        self,
        server_url: str,
        headers: dict[str, str] | None = None,
        timeout: float | None = None,
        sse_read_timeout: float | None = None,
        provider_entity: MCPProviderEntity | None = None,
        authorization_code: str | None = None,
        by_server_id: bool = False,
    ):
        """
        Initialize the MCP client with auth retry capability.

        Args:
            server_url: The MCP server URL
            headers: Optional headers for requests
            timeout: Request timeout
            sse_read_timeout: SSE read timeout
            provider_entity: Provider entity for authentication
            authorization_code: Optional authorization code for initial auth
            by_server_id: Whether to look up the provider by server ID
        """
        super().__init__(server_url, headers, timeout, sse_read_timeout)

        self.provider_entity = provider_entity
        self.authorization_code = authorization_code
        self.by_server_id = by_server_id
        self._has_retried = False

    def _handle_auth_error(self, error: MCPAuthError) -> None:
        """
        Handle an authentication error by refreshing tokens.

        This method creates a short-lived database session only when an authentication
        retry is needed, minimizing database connection hold time.

        Args:
            error: The authentication error

        Raises:
            MCPAuthError: If authentication fails or the retry budget is exhausted
        """
        if not self.provider_entity:
            raise error
        if self._has_retried:
            raise error

        self._has_retried = True

        try:
            # Create a temporary session only for the auth retry.
            # This session is short-lived and only exists during the auth operation.
            from services.tools.mcp_tools_manage_service import MCPToolManageService

            with Session(db.engine) as session, session.begin():
                mcp_service = MCPToolManageService(session=session)

                # Perform authentication using the service's auth method
                mcp_service.auth_with_actions(self.provider_entity, self.authorization_code)

                # Retrieve the new tokens
                self.provider_entity = mcp_service.get_provider_entity(
                    self.provider_entity.id, self.provider_entity.tenant_id, by_server_id=self.by_server_id
                )

            # The session is closed here, before we update headers
            token = self.provider_entity.retrieve_tokens()
            if not token:
                raise MCPAuthError("Authentication failed - no token received")

            # Update headers with the new token
            self.headers["Authorization"] = f"{token.token_type.capitalize()} {token.access_token}"

            # Clear the authorization code after first use
            self.authorization_code = None

        except MCPAuthError:
            # Re-raise MCPAuthError as is
            raise
        except Exception as e:
            # Catch all other exceptions during the auth retry
            logger.exception("Authentication retry failed")
            raise MCPAuthError(f"Authentication retry failed: {e}") from e

    def _execute_with_retry(self, func: Callable[..., Any], *args, **kwargs) -> Any:
        """
        Execute a function with authentication retry logic.

        Args:
            func: The function to execute
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            The result of the function call

        Raises:
            MCPAuthError: If authentication fails after retries
            Any other exceptions from the function
        """
        try:
            return func(*args, **kwargs)
        except MCPAuthError as e:
            self._handle_auth_error(e)

            # Re-initialize the connection with the new headers
            if self._initialized:
                # Clean up the existing connection
                self._exit_stack.close()
                self._session = None
                self._initialized = False

            # Re-initialize with the new headers
            self._initialize()
            self._initialized = True

            return func(*args, **kwargs)
        finally:
            # Reset the retry flag after the operation completes
            self._has_retried = False

    def __enter__(self):
        """Enter the context manager with retry support."""

        def initialize_with_retry():
            super(MCPClientWithAuthRetry, self).__enter__()
            return self

        return self._execute_with_retry(initialize_with_retry)

    def list_tools(self) -> list[Tool]:
        """
        List available tools from the MCP server with auth retry.

        Returns:
            List of available tools

        Raises:
            MCPAuthError: If authentication fails after retries
        """
        return self._execute_with_retry(super().list_tools)

    def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
        """
        Invoke a tool on the MCP server with auth retry.

        Args:
            tool_name: Name of the tool to invoke
            tool_args: Arguments for the tool

        Returns:
            Result of the tool invocation

        Raises:
            MCPAuthError: If authentication fails after retries
        """
        return self._execute_with_retry(super().invoke_tool, tool_name, tool_args)
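A minimal usage sketch; the server URL and tool name are made up, and `provider_entity` is assumed to be an MCPProviderEntity loaded elsewhere:

with MCPClientWithAuthRetry(
    server_url="https://mcp.example.com/sse",
    timeout=5.0,
    sse_read_timeout=60.0,
    provider_entity=provider_entity,
) as client:
    tools = client.list_tools()  # retried once with refreshed credentials on MCPAuthError
    result = client.invoke_tool("search", {"query": "dify"})  # hypothetical tool name and args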
@@ -46,7 +46,7 @@ class SSETransport:
        url: str,
        headers: dict[str, Any] | None = None,
        timeout: float = 5.0,
-       sse_read_timeout: float = 5 * 60,
+       sse_read_timeout: float = 1 * 60,
    ):
        """Initialize the SSE transport.
@@ -255,7 +255,7 @@ def sse_client(
    url: str,
    headers: dict[str, Any] | None = None,
    timeout: float = 5.0,
-   sse_read_timeout: float = 5 * 60,
+   sse_read_timeout: float = 1 * 60,
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
    """
    Client transport for SSE.
@@ -276,31 +276,34 @@ def sse_client(
     read_queue: ReadQueue | None = None
     write_queue: WriteQueue | None = None

-    with ThreadPoolExecutor() as executor:
-        try:
-            with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
-                with ssrf_proxy_sse_connect(
-                    url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
-                ) as event_source:
-                    event_source.response.raise_for_status()
-
-                    read_queue, write_queue = transport.connect(executor, client, event_source)
-
-                    yield read_queue, write_queue
-
-        except httpx.HTTPStatusError as exc:
-            if exc.response.status_code == 401:
-                raise MCPAuthError()
-            raise MCPConnectionError()
-        except Exception:
-            logger.exception("Error connecting to SSE endpoint")
-            raise
-        finally:
-            # Clean up queues
-            if read_queue:
-                read_queue.put(None)
-            if write_queue:
-                write_queue.put(None)
+    executor = ThreadPoolExecutor()
+    try:
+        with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
+            with ssrf_proxy_sse_connect(
+                url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
+            ) as event_source:
+                event_source.response.raise_for_status()
+
+                read_queue, write_queue = transport.connect(executor, client, event_source)
+
+                yield read_queue, write_queue
+
+    except httpx.HTTPStatusError as exc:
+        if exc.response.status_code == 401:
+            raise MCPAuthError()
+        raise MCPConnectionError()
+    except Exception:
+        logger.exception("Error connecting to SSE endpoint")
+        raise
+    finally:
+        # Clean up queues
+        if read_queue:
+            read_queue.put(None)
+        if write_queue:
+            write_queue.put(None)
+
+        # Shutdown executor without waiting to prevent hanging
+        executor.shutdown(wait=False)


 def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
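A sketch of how a caller is expected to consume the generator above. The queue names come from this hunk; the message handling itself is illustrative, and the assumption is that sse_client is wrapped as a context manager (it is entered via an ExitStack later in this comparison):

# Hypothetical consumer of the sse_client transport defined above.
with sse_client("https://example.com/sse") as (read_queue, write_queue):
    while True:
        message = read_queue.get()
        if message is None:  # sentinel pushed by the cleanup code in `finally`
            break
        handle(message)  # placeholder for real message handling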
@@ -434,45 +434,48 @@ def streamablehttp_client(
     server_to_client_queue: ServerToClientQueue = queue.Queue()  # For messages FROM server TO client
     client_to_server_queue: ClientToServerQueue = queue.Queue()  # For messages FROM client TO server

-    with ThreadPoolExecutor(max_workers=2) as executor:
-        try:
-            with create_ssrf_proxy_mcp_http_client(
-                headers=transport.request_headers,
-                timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
-            ) as client:
-                # Define callbacks that need access to thread pool
-                def start_get_stream():
-                    """Start a worker thread to handle server-initiated messages."""
-                    executor.submit(transport.handle_get_stream, client, server_to_client_queue)
-
-                # Start the post_writer worker thread
-                executor.submit(
-                    transport.post_writer,
-                    client,
-                    client_to_server_queue,  # Queue for messages FROM client TO server
-                    server_to_client_queue,  # Queue for messages FROM server TO client
-                    start_get_stream,
-                )
-
-                try:
-                    yield (
-                        server_to_client_queue,  # Queue for receiving messages FROM server
-                        client_to_server_queue,  # Queue for sending messages TO server
-                        transport.get_session_id,
-                    )
-                finally:
-                    if transport.session_id and terminate_on_close:
-                        transport.terminate_session(client)
-
-                    # Signal threads to stop
-                    client_to_server_queue.put(None)
-        finally:
-            # Clear any remaining items and add None sentinel to unblock any waiting threads
-            try:
-                while not client_to_server_queue.empty():
-                    client_to_server_queue.get_nowait()
-            except queue.Empty:
-                pass
-
-            client_to_server_queue.put(None)
-            server_to_client_queue.put(None)
+    executor = ThreadPoolExecutor(max_workers=2)
+    try:
+        with create_ssrf_proxy_mcp_http_client(
+            headers=transport.request_headers,
+            timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
+        ) as client:
+            # Define callbacks that need access to thread pool
+            def start_get_stream():
+                """Start a worker thread to handle server-initiated messages."""
+                executor.submit(transport.handle_get_stream, client, server_to_client_queue)
+
+            # Start the post_writer worker thread
+            executor.submit(
+                transport.post_writer,
+                client,
+                client_to_server_queue,  # Queue for messages FROM client TO server
+                server_to_client_queue,  # Queue for messages FROM server TO client
+                start_get_stream,
+            )
+
+            try:
+                yield (
+                    server_to_client_queue,  # Queue for receiving messages FROM server
+                    client_to_server_queue,  # Queue for sending messages TO server
+                    transport.get_session_id,
+                )
+            finally:
+                if transport.session_id and terminate_on_close:
+                    transport.terminate_session(client)
+
+                # Signal threads to stop
+                client_to_server_queue.put(None)
+    finally:
+        # Clear any remaining items and add None sentinel to unblock any waiting threads
+        try:
+            while not client_to_server_queue.empty():
+                client_to_server_queue.get_nowait()
+        except queue.Empty:
+            pass
+
+        client_to_server_queue.put(None)
+        server_to_client_queue.put(None)
+
+        # Shutdown executor without waiting to prevent hanging
+        executor.shutdown(wait=False)
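Both transports now converge on the same teardown recipe: drain the queue, unblock waiting workers with a None sentinel, then shut the pool down without joining. The same pattern in isolation, using only the standard library:

import queue
from concurrent.futures import ThreadPoolExecutor

q: queue.Queue = queue.Queue()
executor = ThreadPoolExecutor(max_workers=1)
try:
    worker = executor.submit(q.get)  # worker blocks until something arrives
finally:
    try:
        while not q.empty():  # drain leftovers first
            q.get_nowait()
    except queue.Empty:
        pass
    q.put(None)  # sentinel: wakes the blocked worker so it can exit
    # wait=False returns immediately; the sentinel, not the join, is what
    # prevents the hang that the old `with ThreadPoolExecutor()` form caused.
    executor.shutdown(wait=False)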
@@ -1,10 +1,13 @@
 from dataclasses import dataclass
+from enum import StrEnum
 from typing import Any, Generic, TypeVar

+from pydantic import BaseModel
+
 from core.mcp.session.base_session import BaseSession
-from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
+from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAuthMetadata, RequestId, RequestParams

-SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
+SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]


 SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
@@ -17,3 +20,41 @@ class RequestContext(Generic[SessionT, LifespanContextT]):
     meta: RequestParams.Meta | None
     session: SessionT
     lifespan_context: LifespanContextT
+
+
+class AuthActionType(StrEnum):
+    """Types of actions that can be performed during auth flow."""
+
+    SAVE_CLIENT_INFO = "save_client_info"
+    SAVE_TOKENS = "save_tokens"
+    SAVE_CODE_VERIFIER = "save_code_verifier"
+    START_AUTHORIZATION = "start_authorization"
+    SUCCESS = "success"
+
+
+class AuthAction(BaseModel):
+    """Represents an action that needs to be performed as a result of auth flow."""
+
+    action_type: AuthActionType
+    data: dict[str, Any]
+    provider_id: str | None = None
+    tenant_id: str | None = None
+
+
+class AuthResult(BaseModel):
+    """Result of auth function containing actions to be performed and response data."""
+
+    actions: list[AuthAction]
+    response: dict[str, str]
+
+
+class OAuthCallbackState(BaseModel):
+    """State data stored in Redis during OAuth callback flow."""
+
+    provider_id: str
+    tenant_id: str
+    server_url: str
+    metadata: OAuthMetadata | None = None
+    client_information: OAuthClientInformation
+    code_verifier: str
+    redirect_uri: str
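OAuthCallbackState is documented as living in Redis between the authorization redirect and the callback. A sketch of the round-trip under that assumption (the redis client, key prefix, and TTL are hypothetical):

import redis

r = redis.Redis()

def save_state(state_key: str, state: OAuthCallbackState) -> None:
    # Pydantic v2 serializes nested models such as client_information.
    r.setex(f"oauth_state:{state_key}", 600, state.model_dump_json())

def load_state(state_key: str) -> OAuthCallbackState | None:
    raw = r.get(f"oauth_state:{state_key}")
    return OAuthCallbackState.model_validate_json(raw) if raw else None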
@@ -8,3 +8,7 @@ class MCPConnectionError(MCPError):

 class MCPAuthError(MCPConnectionError):
     pass
+
+
+class MCPRefreshTokenError(MCPError):
+    pass
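Because MCPAuthError subclasses MCPConnectionError while the new MCPRefreshTokenError hangs directly off MCPError, handlers must be ordered from most to least specific, roughly like this sketch (reauthorize, reconnect, and give_up are placeholders):

try:
    client.list_tools()  # any call that can raise an MCPError subtype
except MCPAuthError:
    reauthorize()        # most specific: a 401 from the server
except MCPConnectionError:
    reconnect()          # any other transport failure
except MCPError:
    give_up()            # catch-all, including MCPRefreshTokenError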
@@ -7,9 +7,9 @@ from urllib.parse import urlparse

 from core.mcp.client.sse_client import sse_client
 from core.mcp.client.streamable_client import streamablehttp_client
-from core.mcp.error import MCPAuthError, MCPConnectionError
+from core.mcp.error import MCPConnectionError
 from core.mcp.session.client_session import ClientSession
-from core.mcp.types import Tool
+from core.mcp.types import CallToolResult, Tool

 logger = logging.getLogger(__name__)
@@ -18,40 +18,18 @@ class MCPClient:
     def __init__(
         self,
         server_url: str,
-        provider_id: str,
-        tenant_id: str,
-        authed: bool = True,
-        authorization_code: str | None = None,
-        for_list: bool = False,
         headers: dict[str, str] | None = None,
         timeout: float | None = None,
         sse_read_timeout: float | None = None,
     ):
-        # Initialize info
-        self.provider_id = provider_id
-        self.tenant_id = tenant_id
         self.client_type = "streamable"
         self.server_url = server_url
         self.headers = headers or {}
         self.timeout = timeout
         self.sse_read_timeout = sse_read_timeout

-        # Authentication info
-        self.authed = authed
-        self.authorization_code = authorization_code
-        if authed:
-            from core.mcp.auth.auth_provider import OAuthClientProvider
-
-            self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
-            self.token = self.provider.tokens()
-
         # Initialize session and client objects
         self._session: ClientSession | None = None
-        self._streams_context: AbstractContextManager[Any] | None = None
-        self._session_context: ClientSession | None = None
         self._exit_stack = ExitStack()

         # Whether the client has been initialized
         self._initialized = False

     def __enter__(self):
@@ -85,61 +63,42 @@ class MCPClient:
             logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
             self.connect_server(streamablehttp_client, "mcp")

-    def connect_server(
-        self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
-    ):
-        from core.mcp.auth.auth_flow import auth
-
-        try:
-            headers = (
-                {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
-                if self.authed and self.token
-                else self.headers
-            )
-            self._streams_context = client_factory(
-                url=self.server_url,
-                headers=headers,
-                timeout=self.timeout,
-                sse_read_timeout=self.sse_read_timeout,
-            )
-            if not self._streams_context:
-                raise MCPConnectionError("Failed to create connection context")
-
-            # Use exit_stack to manage context managers properly
-            if method_name == "mcp":
-                read_stream, write_stream, _ = self._exit_stack.enter_context(self._streams_context)
-                streams = (read_stream, write_stream)
-            else:  # sse_client
-                streams = self._exit_stack.enter_context(self._streams_context)
-
-            self._session_context = ClientSession(*streams)
-            self._session = self._exit_stack.enter_context(self._session_context)
-            self._session.initialize()
-            return
-
-        except MCPAuthError:
-            if not self.authed:
-                raise
-            try:
-                auth(self.provider, self.server_url, self.authorization_code)
-            except Exception as e:
-                raise ValueError(f"Failed to authenticate: {e}")
-            self.token = self.provider.tokens()
-            if first_try:
-                return self.connect_server(client_factory, method_name, first_try=False)
+    def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
+        """
+        Connect to the MCP server using streamable http or sse.
+        Default to streamable http.
+        Args:
+            client_factory: The client factory to use(streamablehttp_client or sse_client).
+            method_name: The method name to use(mcp or sse).
+        """
+        streams_context = client_factory(
+            url=self.server_url,
+            headers=self.headers,
+            timeout=self.timeout,
+            sse_read_timeout=self.sse_read_timeout,
+        )
+
+        # Use exit_stack to manage context managers properly
+        if method_name == "mcp":
+            read_stream, write_stream, _ = self._exit_stack.enter_context(streams_context)
+            streams = (read_stream, write_stream)
+        else:  # sse_client
+            streams = self._exit_stack.enter_context(streams_context)
+
+        session_context = ClientSession(*streams)
+        self._session = self._exit_stack.enter_context(session_context)
+        self._session.initialize()

     def list_tools(self) -> list[Tool]:
-        """Connect to an MCP server running with SSE transport"""
-        # List available tools to verify connection
-        if not self._initialized or not self._session:
+        """List available tools from the MCP server"""
+        if not self._session:
             raise ValueError("Session not initialized.")
         response = self._session.list_tools()
-        tools = response.tools
-        return tools
+        return response.tools

-    def invoke_tool(self, tool_name: str, tool_args: dict):
+    def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
         """Call a tool"""
-        if not self._initialized or not self._session:
+        if not self._session:
             raise ValueError("Session not initialized.")
         return self._session.call_tool(tool_name, tool_args)
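With authentication moved out of MCPClient, connecting is now just transport plus caller-supplied headers. A sketch; the bearer scheme is an assumption, since token handling is no longer this class's concern:

token = "..."  # obtained elsewhere, e.g. by the auth-retry wrapper shown earlier
with MCPClient(
    "https://example.com/mcp",
    headers={"Authorization": f"Bearer {token}"},
    timeout=10.0,
) as client:
    result = client.invoke_tool("echo", {"text": "hello"})
    print(result.content)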
@@ -153,6 +112,4 @@ class MCPClient:
             raise ValueError(f"Error during cleanup: {e}")
         finally:
             self._session = None
-            self._session_context = None
-            self._streams_context = None
             self._initialized = False
@@ -201,11 +201,14 @@ class BaseSession(
                 self._receiver_future.result(timeout=5.0)  # Wait up to 5 seconds
         except TimeoutError:
             # If the receiver loop is still running after timeout, we'll force shutdown
-            pass
+            # Cancel the future to interrupt the receiver loop
+            self._receiver_future.cancel()

         # Shutdown the executor
         if self._executor:
-            self._executor.shutdown(wait=True)
+            # Use non-blocking shutdown to prevent hanging
+            # The receiver thread should have already exited due to the None message in the queue
+            self._executor.shutdown(wait=False)

     def send_request(
         self,
@@ -109,12 +109,16 @@ class ClientSession(
         self._message_handler = message_handler or _default_message_handler

     def initialize(self) -> types.InitializeResult:
-        sampling = types.SamplingCapability()
-        roots = types.RootsCapability(
-            # TODO: Should this be based on whether we
-            # _will_ send notifications, or only whether
-            # they're supported?
-            listChanged=True,
+        # Only set capabilities if non-default callbacks are provided
+        # This prevents servers from attempting callbacks when we don't actually support them
+        sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
+        roots = (
+            types.RootsCapability(
+                # Only enable listChanged if we have a custom callback
+                listChanged=True,
+            )
+            if self._list_roots_callback is not _default_list_roots_callback
+            else None
         )

         result = self.send_request(
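The guard above generalizes to: advertise a capability only when the caller actually overrode the default callback, so servers never attempt callbacks the client cannot serve. As a standalone sketch (names hypothetical):

def capability_if_overridden(callback, default_callback, make_capability):
    # None means the capability is simply omitted from the initialize handshake.
    return make_capability() if callback is not default_callback else None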
@@ -284,7 +288,7 @@ class ClientSession(

     def complete(
         self,
-        ref: types.ResourceReference | types.PromptReference,
+        ref: types.ResourceTemplateReference | types.PromptReference,
         argument: dict[str, str],
     ) -> types.CompleteResult:
         """Send a completion/complete request."""
@@ -1,13 +1,6 @@
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import (
-    Annotated,
-    Any,
-    Generic,
-    Literal,
-    TypeAlias,
-    TypeVar,
-)
+from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar

 from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
 from pydantic.networks import AnyUrl, UrlConstraints
@@ -33,6 +26,7 @@ for reference.
 LATEST_PROTOCOL_VERSION = "2025-03-26"
 # Server support 2024-11-05 to allow claude to use.
 SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
+DEFAULT_NEGOTIATED_VERSION = "2025-03-26"
 ProgressToken = str | int
 Cursor = str
 Role = Literal["user", "assistant"]
@@ -55,14 +49,22 @@ class RequestParams(BaseModel):
     meta: Meta | None = Field(alias="_meta", default=None)


+class PaginatedRequestParams(RequestParams):
+    cursor: Cursor | None = None
+    """
+    An opaque token representing the current pagination position.
+    If provided, the server should return results starting after this cursor.
+    """
+
+
 class NotificationParams(BaseModel):
     class Meta(BaseModel):
         model_config = ConfigDict(extra="allow")

     meta: Meta | None = Field(alias="_meta", default=None)
     """
     This parameter name is reserved by MCP to allow clients and servers to attach
     additional metadata to their notifications.
     See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
     for notes on _meta usage.
     """
@@ -79,12 +81,11 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
     model_config = ConfigDict(extra="allow")


-class PaginatedRequest(Request[RequestParamsT, MethodT]):
-    cursor: Cursor | None = None
-    """
-    An opaque token representing the current pagination position.
-    If provided, the server should return results starting after this cursor.
-    """
+class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
+    """Base class for paginated requests,
+    matching the schema's PaginatedRequest interface."""
+
+    params: PaginatedRequestParams | None = None


 class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
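The practical difference: the cursor now travels inside params, and subclasses no longer need to spell out method or redeclare params. A sketch using the request models defined later in this file (the cursor value is invented; on the result side, following the next-page token in PaginatedResult is the usual MCP convention):

first_page = ListToolsRequest()  # method defaults to "tools/list"
next_page = ListToolsRequest(params=PaginatedRequestParams(cursor="abc123"))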
@@ -98,13 +99,12 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
 class Result(BaseModel):
     """Base class for JSON-RPC results."""

-    model_config = ConfigDict(extra="allow")
-
     meta: dict[str, Any] | None = Field(alias="_meta", default=None)
     """
     This result property is reserved by the protocol to allow clients and servers to
     attach additional metadata to their responses.
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
     """
+    model_config = ConfigDict(extra="allow")


 class PaginatedResult(Result):
@@ -186,10 +186,26 @@ class EmptyResult(Result):
     """A response that indicates success but carries no data."""


-class Implementation(BaseModel):
-    """Describes the name and version of an MCP implementation."""
+class BaseMetadata(BaseModel):
+    """Base class for entities with name and optional title fields."""

     name: str
+    """The programmatic name of the entity."""
+
+    title: str | None = None
+    """
+    Intended for UI and end-user contexts — optimized to be human-readable and easily understood,
+    even by those unfamiliar with domain-specific terminology.
+
+    If not provided, the name should be used for display (except for Tool,
+    where `annotations.title` should be given precedence over using `name`,
+    if present).
+    """
+
+
+class Implementation(BaseMetadata):
+    """Describes the name and version of an MCP implementation."""
+
     version: str
     model_config = ConfigDict(extra="allow")
@@ -203,7 +219,7 @@ class RootsCapability(BaseModel):


 class SamplingCapability(BaseModel):
-    """Capability for logging operations."""
+    """Capability for sampling operations."""

     model_config = ConfigDict(extra="allow")
@@ -252,6 +268,12 @@ class LoggingCapability(BaseModel):
     model_config = ConfigDict(extra="allow")


+class CompletionsCapability(BaseModel):
+    """Capability for completions operations."""
+
+    model_config = ConfigDict(extra="allow")
+
+
 class ServerCapabilities(BaseModel):
     """Capabilities that a server may support."""

@@ -265,6 +287,8 @@ class ServerCapabilities(BaseModel):
     """Present if the server offers any resources to read."""
     tools: ToolsCapability | None = None
     """Present if the server offers any tools to call."""
+    completions: CompletionsCapability | None = None
+    """Present if the server offers autocompletion suggestions for prompts and resources."""
     model_config = ConfigDict(extra="allow")
@@ -284,7 +308,7 @@ class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]])
     to begin initialization.
     """

-    method: Literal["initialize"]
+    method: Literal["initialize"] = "initialize"
     params: InitializeRequestParams


@@ -305,7 +329,7 @@ class InitializedNotification(Notification[NotificationParams | None, Literal["n
     finished.
     """

-    method: Literal["notifications/initialized"]
+    method: Literal["notifications/initialized"] = "notifications/initialized"
     params: NotificationParams | None = None


@@ -315,7 +339,7 @@ class PingRequest(Request[RequestParams | None, Literal["ping"]]):
     still alive.
     """

-    method: Literal["ping"]
+    method: Literal["ping"] = "ping"
     params: RequestParams | None = None


@@ -334,6 +358,11 @@ class ProgressNotificationParams(NotificationParams):
     """
     total: float | None = None
     """Total number of items to process (or total progress required), if known."""
+    message: str | None = None
+    """
+    Message related to progress. This should provide relevant human readable
+    progress information.
+    """
     model_config = ConfigDict(extra="allow")


@@ -343,15 +372,14 @@ class ProgressNotification(Notification[ProgressNotificationParams, Literal["not
     long-running request.
     """

-    method: Literal["notifications/progress"]
+    method: Literal["notifications/progress"] = "notifications/progress"
     params: ProgressNotificationParams


-class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]):
+class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
     """Sent from the client to request a list of resources the server has."""

-    method: Literal["resources/list"]
-    params: RequestParams | None = None
+    method: Literal["resources/list"] = "resources/list"


 class Annotations(BaseModel):
@@ -360,13 +388,11 @@ class Annotations(BaseModel):
     model_config = ConfigDict(extra="allow")


-class Resource(BaseModel):
+class Resource(BaseMetadata):
     """A known resource that the server is capable of reading."""

     uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
     """The URI of this resource."""
-    name: str
-    """A human-readable name for this resource."""
     description: str | None = None
     """A description of what this resource represents."""
     mimeType: str | None = None
@@ -379,10 +405,15 @@ class Resource(BaseModel):
     This can be used by Hosts to display file sizes and estimate context window usage.
     """
     annotations: Annotations | None = None
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
     model_config = ConfigDict(extra="allow")


-class ResourceTemplate(BaseModel):
+class ResourceTemplate(BaseMetadata):
     """A template description for resources available on the server."""

     uriTemplate: str
@@ -390,8 +421,6 @@ class ResourceTemplate(BaseModel):
     A URI template (according to RFC 6570) that can be used to construct resource
     URIs.
     """
-    name: str
-    """A human-readable name for the type of resource this template refers to."""
     description: str | None = None
     """A human-readable description of what this template is for."""
     mimeType: str | None = None
@@ -400,6 +429,11 @@ class ResourceTemplate(BaseModel):
     included if all resources matching this template have the same type.
     """
     annotations: Annotations | None = None
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
     model_config = ConfigDict(extra="allow")
@@ -409,11 +443,10 @@ class ListResourcesResult(PaginatedResult):
     resources: list[Resource]


-class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]):
+class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]):
     """Sent from the client to request a list of resource templates the server has."""

-    method: Literal["resources/templates/list"]
-    params: RequestParams | None = None
+    method: Literal["resources/templates/list"] = "resources/templates/list"


 class ListResourceTemplatesResult(PaginatedResult):
@@ -436,7 +469,7 @@ class ReadResourceRequestParams(RequestParams):
 class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
     """Sent from the client to the server, to read a specific resource URI."""

-    method: Literal["resources/read"]
+    method: Literal["resources/read"] = "resources/read"
     params: ReadResourceRequestParams


@@ -447,6 +480,11 @@ class ResourceContents(BaseModel):
     """The URI of this resource."""
     mimeType: str | None = None
     """The MIME type of this resource, if known."""
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
     model_config = ConfigDict(extra="allow")


@@ -481,7 +519,7 @@ class ResourceListChangedNotification(
     of resources it can read from has changed.
     """

-    method: Literal["notifications/resources/list_changed"]
+    method: Literal["notifications/resources/list_changed"] = "notifications/resources/list_changed"
     params: NotificationParams | None = None


@@ -502,7 +540,7 @@ class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscr
     whenever a particular resource changes.
     """

-    method: Literal["resources/subscribe"]
+    method: Literal["resources/subscribe"] = "resources/subscribe"
     params: SubscribeRequestParams


@@ -520,7 +558,7 @@ class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/un
     the server.
     """

-    method: Literal["resources/unsubscribe"]
+    method: Literal["resources/unsubscribe"] = "resources/unsubscribe"
     params: UnsubscribeRequestParams
@@ -543,15 +581,14 @@ class ResourceUpdatedNotification(
     changed and may need to be read again.
     """

-    method: Literal["notifications/resources/updated"]
+    method: Literal["notifications/resources/updated"] = "notifications/resources/updated"
     params: ResourceUpdatedNotificationParams


-class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]):
+class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
     """Sent from the client to request a list of prompts and prompt templates."""

-    method: Literal["prompts/list"]
-    params: RequestParams | None = None
+    method: Literal["prompts/list"] = "prompts/list"


 class PromptArgument(BaseModel):
@@ -566,15 +603,18 @@ class PromptArgument(BaseModel):
     model_config = ConfigDict(extra="allow")


-class Prompt(BaseModel):
+class Prompt(BaseMetadata):
     """A prompt or prompt template that the server offers."""

-    name: str
-    """The name of the prompt or prompt template."""
     description: str | None = None
     """An optional description of what this prompt provides."""
     arguments: list[PromptArgument] | None = None
     """A list of arguments to use for templating the prompt."""
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
     model_config = ConfigDict(extra="allow")
@@ -597,7 +637,7 @@ class GetPromptRequestParams(RequestParams):
 class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
     """Used by the client to get a prompt provided by the server."""

-    method: Literal["prompts/get"]
+    method: Literal["prompts/get"] = "prompts/get"
     params: GetPromptRequestParams


@@ -608,6 +648,11 @@ class TextContent(BaseModel):
     text: str
     """The text content of the message."""
     annotations: Annotations | None = None
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
     model_config = ConfigDict(extra="allow")


@@ -623,6 +668,31 @@ class ImageContent(BaseModel):
     image types.
     """
     annotations: Annotations | None = None
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
     model_config = ConfigDict(extra="allow")


+class AudioContent(BaseModel):
+    """Audio content for a message."""
+
+    type: Literal["audio"]
+    data: str
+    """The base64-encoded audio data."""
+    mimeType: str
+    """
+    The MIME type of the audio. Different providers may support different
+    audio types.
+    """
+    annotations: Annotations | None = None
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
+    model_config = ConfigDict(extra="allow")
@@ -630,7 +700,7 @@ class SamplingMessage(BaseModel):
     """Describes a message issued to or received from an LLM API."""

     role: Role
-    content: TextContent | ImageContent
+    content: TextContent | ImageContent | AudioContent
     model_config = ConfigDict(extra="allow")


@@ -645,14 +715,36 @@ class EmbeddedResource(BaseModel):
     type: Literal["resource"]
     resource: TextResourceContents | BlobResourceContents
     annotations: Annotations | None = None
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
     model_config = ConfigDict(extra="allow")


+class ResourceLink(Resource):
+    """
+    A resource that the server is capable of reading, included in a prompt or tool call result.
+
+    Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests.
+    """
+
+    type: Literal["resource_link"]
+
+
+ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
+"""A content block that can be used in prompts and tool results."""
+
+Content: TypeAlias = ContentBlock
+# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
+
+
 class PromptMessage(BaseModel):
     """Describes a message returned as part of a prompt."""

     role: Role
-    content: TextContent | ImageContent | EmbeddedResource
+    content: ContentBlock
     model_config = ConfigDict(extra="allow")


@@ -672,15 +764,14 @@ class PromptListChangedNotification(
     of prompts it offers has changed.
     """

-    method: Literal["notifications/prompts/list_changed"]
+    method: Literal["notifications/prompts/list_changed"] = "notifications/prompts/list_changed"
     params: NotificationParams | None = None


-class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
+class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
     """Sent from the client to request a list of tools the server has."""

-    method: Literal["tools/list"]
-    params: RequestParams | None = None
+    method: Literal["tools/list"] = "tools/list"


 class ToolAnnotations(BaseModel):
@@ -731,17 +822,25 @@ class ToolAnnotations(BaseModel):
     model_config = ConfigDict(extra="allow")


-class Tool(BaseModel):
+class Tool(BaseMetadata):
     """Definition for a tool the client can call."""

-    name: str
-    """The name of the tool."""
     description: str | None = None
     """A human-readable description of the tool."""
     inputSchema: dict[str, Any]
     """A JSON Schema object defining the expected parameters for the tool."""
+    outputSchema: dict[str, Any] | None = None
+    """
+    An optional JSON Schema object defining the structure of the tool's output
+    returned in the structuredContent field of a CallToolResult.
+    """
     annotations: ToolAnnotations | None = None
     """Optional additional tool information."""
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
     model_config = ConfigDict(extra="allow")
@@ -762,14 +861,16 @@ class CallToolRequestParams(RequestParams):
 class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
     """Used by the client to invoke a tool provided by the server."""

-    method: Literal["tools/call"]
+    method: Literal["tools/call"] = "tools/call"
     params: CallToolRequestParams


 class CallToolResult(Result):
     """The server's response to a tool call."""

-    content: list[TextContent | ImageContent | EmbeddedResource]
+    content: list[ContentBlock]
+    structuredContent: dict[str, Any] | None = None
+    """An optional JSON object that represents the structured result of the tool call."""
     isError: bool = False


@@ -779,7 +880,7 @@ class ToolListChangedNotification(Notification[NotificationParams | None, Litera
     of tools it offers has changed.
     """

-    method: Literal["notifications/tools/list_changed"]
+    method: Literal["notifications/tools/list_changed"] = "notifications/tools/list_changed"
     params: NotificationParams | None = None
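A server-side sketch of the widened result: content takes any ContentBlock, and structuredContent can mirror it in machine-readable form (values invented):

result = CallToolResult(
    content=[TextContent(type="text", text="3 files found")],
    structuredContent={"file_count": 3},  # should conform to the tool's outputSchema
    isError=False,
)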
@@ -797,7 +898,7 @@ class SetLevelRequestParams(RequestParams):
 class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
     """A request from the client to the server, to enable or adjust logging."""

-    method: Literal["logging/setLevel"]
+    method: Literal["logging/setLevel"] = "logging/setLevel"
     params: SetLevelRequestParams


@@ -808,7 +909,7 @@ class LoggingMessageNotificationParams(NotificationParams):
     """The severity of this log message."""
     logger: str | None = None
     """An optional name of the logger issuing this message."""
-    data: Any = None
+    data: Any
     """
     The data to be logged, such as a string message or an object. Any JSON serializable
     type is allowed here.
@@ -819,7 +920,7 @@ class LoggingMessageNotificationParams(NotificationParams):
 class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
     """Notification of a log message passed from server to client."""

-    method: Literal["notifications/message"]
+    method: Literal["notifications/message"] = "notifications/message"
     params: LoggingMessageNotificationParams
@@ -914,7 +1015,7 @@ class CreateMessageRequestParams(RequestParams):
 class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
     """A request from the server to sample an LLM via the client."""

-    method: Literal["sampling/createMessage"]
+    method: Literal["sampling/createMessage"] = "sampling/createMessage"
     params: CreateMessageRequestParams


@@ -925,14 +1026,14 @@ class CreateMessageResult(Result):
     """The client's response to a sampling/create_message request from the server."""

     role: Role
-    content: TextContent | ImageContent
+    content: TextContent | ImageContent | AudioContent
     model: str
     """The name of the model that generated the message."""
     stopReason: StopReason | None = None
     """The reason why sampling stopped, if known."""


-class ResourceReference(BaseModel):
+class ResourceTemplateReference(BaseModel):
     """A reference to a resource or resource template definition."""

     type: Literal["ref/resource"]
@@ -960,18 +1061,28 @@ class CompletionArgument(BaseModel):
     model_config = ConfigDict(extra="allow")


+class CompletionContext(BaseModel):
+    """Additional, optional context for completions."""
+
+    arguments: dict[str, str] | None = None
+    """Previously-resolved variables in a URI template or prompt."""
+    model_config = ConfigDict(extra="allow")
+
+
 class CompleteRequestParams(RequestParams):
     """Parameters for completion requests."""

-    ref: ResourceReference | PromptReference
+    ref: ResourceTemplateReference | PromptReference
     argument: CompletionArgument
+    context: CompletionContext | None = None
+    """Additional, optional context for completions"""
     model_config = ConfigDict(extra="allow")


 class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
     """A request from the client to the server, to ask for completion options."""

-    method: Literal["completion/complete"]
+    method: Literal["completion/complete"] = "completion/complete"
     params: CompleteRequestParams
@@ -1010,7 +1121,7 @@ class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
     structure or access specific locations that the client has permission to read from.
     """

-    method: Literal["roots/list"]
+    method: Literal["roots/list"] = "roots/list"
     params: RequestParams | None = None


@@ -1029,6 +1140,11 @@ class Root(BaseModel):
     identifier for the root, which may be useful for display purposes or for
     referencing the root in other parts of the application.
     """
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+    """
     model_config = ConfigDict(extra="allow")


@@ -1054,7 +1170,7 @@ class RootsListChangedNotification(
     using the ListRootsRequest.
     """

-    method: Literal["notifications/roots/list_changed"]
+    method: Literal["notifications/roots/list_changed"] = "notifications/roots/list_changed"
     params: NotificationParams | None = None


@@ -1074,7 +1190,7 @@ class CancelledNotification(Notification[CancelledNotificationParams, Literal["n
     previously-issued request.
     """

-    method: Literal["notifications/cancelled"]
+    method: Literal["notifications/cancelled"] = "notifications/cancelled"
     params: CancelledNotificationParams
@@ -1,6 +1,7 @@
 from collections.abc import Sequence

 from sqlalchemy import select
+from sqlalchemy.orm import sessionmaker

 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file import file_manager
@@ -18,7 +19,9 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
 from extensions.ext_database import db
 from factories import file_factory
 from models.model import AppMode, Conversation, Message, MessageFile
-from models.workflow import Workflow, WorkflowRun
+from models.workflow import Workflow
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+from repositories.factory import DifyAPIRepositoryFactory


 class TokenBufferMemory:
@@ -29,6 +32,14 @@ class TokenBufferMemory:
     ):
         self.conversation = conversation
         self.model_instance = model_instance
+        self._workflow_run_repo: APIWorkflowRunRepository | None = None
+
+    @property
+    def workflow_run_repo(self) -> APIWorkflowRunRepository:
+        if self._workflow_run_repo is None:
+            session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+            self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+        return self._workflow_run_repo

     def _build_prompt_message_with_files(
         self,
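The lazy property above defers building the sessionmaker until the repository is first needed, and expire_on_commit=False keeps returned rows readable after the repository's session has closed. A sketch of the consuming side (tenant_id, app_id, and run_id are assumed to be known):

repo = memory.workflow_run_repo  # built on first access, cached afterwards
run = repo.get_workflow_run_by_id(tenant_id=tenant_id, app_id=app_id, run_id=run_id)
# Safe even though the session is gone: attributes were not expired on commit.
print(run.workflow_id if run else "not found")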
@@ -50,7 +61,16 @@ class TokenBufferMemory:
         if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
             file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
         elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
-            workflow_run = db.session.scalar(select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id))
+            app = self.conversation.app
+            if not app:
+                raise ValueError("App not found for conversation")
+
+            if not message.workflow_run_id:
+                raise ValueError("Workflow run ID not found")
+
+            workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
+                tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
+            )
+            if not workflow_run:
+                raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
             workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
@@ -38,6 +38,8 @@ class LLMUsageMetadata(TypedDict, total=False):
     prompt_price: Union[float, str]
     completion_price: Union[float, str]
     latency: float
+    time_to_first_token: float
+    time_to_generate: float


 class LLMUsage(ModelUsage):
@@ -57,6 +59,8 @@ class LLMUsage(ModelUsage):
     total_price: Decimal
     currency: str
     latency: float
+    time_to_first_token: float | None = None
+    time_to_generate: float | None = None

     @classmethod
     def empty_usage(cls):
@@ -73,6 +77,8 @@ class LLMUsage(ModelUsage):
             total_price=Decimal("0.0"),
             currency="USD",
             latency=0.0,
+            time_to_first_token=None,
+            time_to_generate=None,
         )

     @classmethod
@@ -108,6 +114,8 @@ class LLMUsage(ModelUsage):
             prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
             completion_price=Decimal(str(metadata.get("completion_price", 0))),
             latency=metadata.get("latency", 0.0),
+            time_to_first_token=metadata.get("time_to_first_token"),
+            time_to_generate=metadata.get("time_to_generate"),
         )

     def plus(self, other: LLMUsage) -> LLMUsage:
@@ -133,6 +141,8 @@ class LLMUsage(ModelUsage):
             total_price=self.total_price + other.total_price,
             currency=other.currency,
             latency=self.latency + other.latency,
+            time_to_first_token=other.time_to_first_token,
+            time_to_generate=other.time_to_generate,
        )

     def __add__(self, other: LLMUsage) -> LLMUsage:
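The two new fields are plain wall-clock measurements; note that plus() accumulates latency but simply carries the timing fields over from the newest operand. A hedged sketch of how a streaming handler might populate them (the stream object is hypothetical; only the arithmetic is the point):

import time

start = time.perf_counter()
first_token_at = None
for chunk in llm_stream:  # hypothetical token stream
    if first_token_at is None:
        first_token_at = time.perf_counter()
end = time.perf_counter()

time_to_first_token = (first_token_at - start) if first_token_at else None
time_to_generate = (end - first_token_at) if first_token_at else None
latency = end - start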
@@ -1,21 +1,22 @@
-import hashlib
 import json
 import logging
 import os
+import traceback
 from datetime import datetime, timedelta
 from typing import Any, Union, cast
 from urllib.parse import urlparse

-from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
-from opentelemetry import trace
+from openinference.semconv.trace import OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes
 from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
 from opentelemetry.sdk import trace as trace_sdk
 from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
-from opentelemetry.trace import SpanContext, TraceFlags, TraceState
-from sqlalchemy import select
+from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
+from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+from opentelemetry.util.types import AttributeValue
 from sqlalchemy.orm import sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
@@ -30,9 +31,10 @@ from core.ops.entities.trace_entity import (
     TraceTaskName,
     WorkflowTraceInfo,
 )
+from core.repositories import DifyCoreRepositoryFactory
 from extensions.ext_database import db
 from models.model import EndUser, MessageFile
-from models.workflow import WorkflowNodeExecutionModel
+from models.workflow import WorkflowNodeExecutionTriggeredFrom

 logger = logging.getLogger(__name__)
@@ -99,22 +101,45 @@ def datetime_to_nanos(dt: datetime | None) -> int:
     return int(dt.timestamp() * 1_000_000_000)


-def string_to_trace_id128(string: str | None) -> int:
-    """
-    Convert any input string into a stable 128-bit integer trace ID.
-
-    This uses SHA-256 hashing and takes the first 16 bytes (128 bits) of the digest.
-    It's suitable for generating consistent, unique identifiers from strings.
-    """
-    if string is None:
-        string = ""
-    hash_object = hashlib.sha256(string.encode())
-
-    # Take the first 16 bytes (128 bits) of the hash digest
-    digest = hash_object.digest()[:16]
-
-    # Convert to a 128-bit integer
-    return int.from_bytes(digest, byteorder="big")
+def error_to_string(error: Exception | str | None) -> str:
+    """Convert an error to a string with traceback information."""
+    error_message = "Empty Stack Trace"
+    if error:
+        if isinstance(error, Exception):
+            string_stacktrace = "".join(traceback.format_exception(error))
+            error_message = f"{error.__class__.__name__}: {error}\n\n{string_stacktrace}"
+        else:
+            error_message = str(error)
+    return error_message
+
+
+def set_span_status(current_span: Span, error: Exception | str | None = None):
+    """Set the status of the current span based on the presence of an error."""
+    if error:
+        error_string = error_to_string(error)
+        current_span.set_status(Status(StatusCode.ERROR, error_string))
+
+        if isinstance(error, Exception):
+            current_span.record_exception(error)
+        else:
+            exception_type = error.__class__.__name__
+            exception_message = str(error)
+            if not exception_message:
+                exception_message = repr(error)
+            attributes: dict[str, AttributeValue] = {
+                OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
+                OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
+                OTELSpanAttributes.EXCEPTION_ESCAPED: False,
+                OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
+            }
+            current_span.add_event(name="exception", attributes=attributes)
+    else:
+        current_span.set_status(Status(StatusCode.OK))
+
+
+def safe_json_dumps(obj: Any) -> str:
+    """A convenience wrapper around `json.dumps` that ensures that any object can be safely encoded."""
+    return json.dumps(obj, default=str, ensure_ascii=False)


 class ArizePhoenixDataTrace(BaseTraceInstance):
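A quick sketch of the two helpers in use outside the tracer (the tracer is assumed to be configured; risky_operation is a placeholder):

from datetime import datetime

print(safe_json_dumps({"started_at": datetime.now()}))  # datetime falls back to str()

with tracer.start_as_current_span("demo") as span:
    try:
        risky_operation()
    except Exception as e:
        set_span_status(span, e)  # ERROR status plus a recorded exception event
    else:
        set_span_status(span)     # OK status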
@@ -131,9 +156,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
         self.tracer, self.processor = setup_tracer(arize_phoenix_config)
         self.project = arize_phoenix_config.project
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
+        self.propagator = TraceContextTextMapPropagator()
+        self.dify_trace_ids: set[str] = set()

     def trace(self, trace_info: BaseTraceInfo):
-        logger.info("[Arize/Phoenix] Trace: %s", trace_info)
+        logger.info("[Arize/Phoenix] Trace Entity Info: %s", trace_info)
+        logger.info("[Arize/Phoenix] Trace Entity Type: %s", type(trace_info))
         try:
             if isinstance(trace_info, WorkflowTraceInfo):
                 self.workflow_trace(trace_info)
@@ -151,7 +179,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
             self.generate_name_trace(trace_info)

         except Exception as e:
-            logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True)
+            logger.error("[Arize/Phoenix] Trace Entity Error: %s", str(e), exc_info=True)
             raise

     def workflow_trace(self, trace_info: WorkflowTraceInfo):
@ -166,15 +194,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
}
|
||||
workflow_metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.workflow_run_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
self.ensure_root_span(dify_trace_id)
|
||||
root_span_context = self.propagator.extract(carrier=self.carrier)
|
||||
|
||||
workflow_span = self.tracer.start_span(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
@ -186,31 +208,58 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
context=root_span_context,
|
||||
)
|
||||
|
||||
# Through workflow_run_id, get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
|
||||
# Find the app's creator account
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
service_account = self.get_service_account_with_tenant(app_id)
|
||||
|
||||
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||
workflow_run_id=trace_info.workflow_run_id
|
||||
)
|
||||
|
||||
try:
|
||||
# Process workflow nodes
|
||||
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
|
||||
for node_execution in workflow_node_executions:
|
||||
tenant_id = trace_info.tenant_id # Use from trace_info instead
|
||||
app_id = trace_info.metadata.get("app_id") # Use from trace_info instead
|
||||
inputs_value = node_execution.inputs or {}
|
||||
outputs_value = node_execution.outputs or {}
|
||||
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
process_data = node_execution.process_data or {}
|
||||
execution_metadata = node_execution.metadata or {}
|
||||
node_metadata = {str(k): v for k, v in execution_metadata.items()}
|
||||
|
||||
node_metadata = {
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": node_execution.tenant_id,
|
||||
"app_id": node_execution.app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
|
||||
}
|
||||
|
||||
if node_execution.execution_metadata:
|
||||
node_metadata.update(json.loads(node_execution.execution_metadata))
|
||||
node_metadata.update(
|
||||
{
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"level": "ERROR" if node_execution.status == "failed" else "DEFAULT",
|
||||
}
|
||||
)
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
@@ -223,8 +272,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
if model:
    node_metadata["ls_model_name"] = model

outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
usage_data = (
    process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
)
if usage_data:
    node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
    node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
@@ -236,17 +286,20 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
else:
    span_kind = OpenInferenceSpanKindValues.CHAIN

workflow_span_context = set_span_in_context(workflow_span)
node_span = self.tracer.start_span(
    name=node_execution.node_type,
    attributes={
        SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
        SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
        SpanAttributes.INPUT_VALUE: safe_json_dumps(inputs_value),
        SpanAttributes.INPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
        SpanAttributes.OUTPUT_VALUE: safe_json_dumps(outputs_value),
        SpanAttributes.OUTPUT_MIME_TYPE: OpenInferenceMimeTypeValues.JSON.value,
        SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind.value,
        SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
        SpanAttributes.METADATA: safe_json_dumps(node_metadata),
        SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
    },
    start_time=datetime_to_nanos(created_at),
    context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
    context=workflow_span_context,
)

try:
@@ -260,11 +313,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
    llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
    if model:
        llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
    outputs = (
        json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
    )
    usage_data = (
        process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
        process_data.get("usage", {}) if "usage" in process_data else outputs_value.get("usage", {})
    )
    if usage_data:
        llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
@@ -275,8 +325,16 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
    llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
    node_span.set_attributes(llm_attributes)
finally:
    if node_execution.status == "failed":
        set_span_status(node_span, node_execution.error)
    else:
        set_span_status(node_span)
    node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
    if trace_info.error:
        set_span_status(workflow_span, trace_info.error)
    else:
        set_span_status(workflow_span)
    workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))

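Aside: the hunks above replace the hand-built `NonRecordingSpan` parent with a context derived from the live workflow span, so node spans become real children of it. A minimal sketch of that parenting pattern (illustrative names, not Dify's code):

```python
# Parent/child span wiring via an explicit Context, as used for the node spans above.
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import set_span_in_context

tracer = TracerProvider().get_tracer("sketch")

workflow_span = tracer.start_span("workflow")
try:
    # set_span_in_context returns a Context whose active span is workflow_span,
    # so every span started with that context becomes a direct child of it.
    parent_ctx = set_span_in_context(workflow_span)
    for name in ("node-1", "node-2"):
        node_span = tracer.start_span(name, context=parent_ctx)
        node_span.end()
finally:
    workflow_span.end()
```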
def message_trace(self, trace_info: MessageTraceInfo):
@@ -322,34 +380,18 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
    }

    trace_id = string_to_trace_id128(trace_info.trace_id or trace_info.message_id)
    message_span_id = RandomIdGenerator().generate_span_id()
    span_context = SpanContext(
        trace_id=trace_id,
        span_id=message_span_id,
        is_remote=False,
        trace_flags=TraceFlags(TraceFlags.SAMPLED),
        trace_state=TraceState(),
    )
    dify_trace_id = trace_info.trace_id or trace_info.message_id
    self.ensure_root_span(dify_trace_id)
    root_span_context = self.propagator.extract(carrier=self.carrier)

    message_span = self.tracer.start_span(
        name=TraceTaskName.MESSAGE_TRACE.value,
        attributes=attributes,
        start_time=datetime_to_nanos(trace_info.start_time),
        context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
        context=root_span_context,
    )

    try:
        if trace_info.error:
            message_span.add_event(
                "exception",
                attributes={
                    "exception.message": trace_info.error,
                    "exception.type": "Error",
                    "exception.stacktrace": trace_info.error,
                },
            )

        # Convert outputs to string based on type
        if isinstance(trace_info.outputs, dict | list):
            outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
@@ -383,26 +425,26 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        if model_params := metadata_dict.get("model_parameters"):
            llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)

        message_span_context = set_span_in_context(message_span)
        llm_span = self.tracer.start_span(
            name="llm",
            attributes=llm_attributes,
            start_time=datetime_to_nanos(trace_info.start_time),
            context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
            context=message_span_context,
        )

        try:
            if trace_info.error:
                llm_span.add_event(
                    "exception",
                    attributes={
                        "exception.message": trace_info.error,
                        "exception.type": "Error",
                        "exception.stacktrace": trace_info.error,
                    },
                )
            if trace_info.message_data.error:
                set_span_status(llm_span, trace_info.message_data.error)
            else:
                set_span_status(llm_span)
        finally:
            llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
    finally:
        if trace_info.error:
            set_span_status(message_span, trace_info.error)
        else:
            set_span_status(message_span)
        message_span.end(end_time=datetime_to_nanos(trace_info.end_time))

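The same error-reporting shape recurs in every trace method below: record an `exception` event, set the span status, and end the span in a `finally`. A hedged sketch of that pattern, where `set_span_status` stands in for Dify's helper of the same name (its assumed behavior is noted in the comment):

```python
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Status, StatusCode

tracer = TracerProvider().get_tracer("sketch")

def set_span_status(span, error: str | None = None) -> None:
    # Assumed behavior of the helper used above: ERROR with a message, OK otherwise.
    span.set_status(Status(StatusCode.ERROR, error) if error else Status(StatusCode.OK))

span = tracer.start_span("message")
error = "model timed out"  # stand-in for trace_info.error
try:
    if error:
        span.add_event(
            "exception",
            attributes={
                "exception.message": error,
                "exception.type": "Error",
                "exception.stacktrace": error,
            },
        )
        set_span_status(span, error)
    else:
        set_span_status(span)
finally:
    span.end()
```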
def moderation_trace(self, trace_info: ModerationTraceInfo):
@@ -418,15 +460,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
    }
    metadata.update(trace_info.metadata)

    trace_id = string_to_trace_id128(trace_info.message_id)
    span_id = RandomIdGenerator().generate_span_id()
    context = SpanContext(
        trace_id=trace_id,
        span_id=span_id,
        is_remote=False,
        trace_flags=TraceFlags(TraceFlags.SAMPLED),
        trace_state=TraceState(),
    )
    dify_trace_id = trace_info.trace_id or trace_info.message_id
    self.ensure_root_span(dify_trace_id)
    root_span_context = self.propagator.extract(carrier=self.carrier)

    span = self.tracer.start_span(
        name=TraceTaskName.MODERATION_TRACE.value,
@@ -445,19 +481,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
        },
        start_time=datetime_to_nanos(trace_info.start_time),
        context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
        context=root_span_context,
    )

    try:
        if trace_info.message_data.error:
            span.add_event(
                "exception",
                attributes={
                    "exception.message": trace_info.message_data.error,
                    "exception.type": "Error",
                    "exception.stacktrace": trace_info.message_data.error,
                },
            )
            set_span_status(span, trace_info.message_data.error)
        else:
            set_span_status(span)
    finally:
        span.end(end_time=datetime_to_nanos(trace_info.end_time))

@@ -480,15 +511,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
    }
    metadata.update(trace_info.metadata)

    trace_id = string_to_trace_id128(trace_info.message_id)
    span_id = RandomIdGenerator().generate_span_id()
    context = SpanContext(
        trace_id=trace_id,
        span_id=span_id,
        is_remote=False,
        trace_flags=TraceFlags(TraceFlags.SAMPLED),
        trace_state=TraceState(),
    )
    dify_trace_id = trace_info.trace_id or trace_info.message_id
    self.ensure_root_span(dify_trace_id)
    root_span_context = self.propagator.extract(carrier=self.carrier)

    span = self.tracer.start_span(
        name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
@@ -499,19 +524,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
        },
        start_time=datetime_to_nanos(start_time),
        context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
        context=root_span_context,
    )

    try:
        if trace_info.error:
            span.add_event(
                "exception",
                attributes={
                    "exception.message": trace_info.error,
                    "exception.type": "Error",
                    "exception.stacktrace": trace_info.error,
                },
            )
            set_span_status(span, trace_info.error)
        else:
            set_span_status(span)
    finally:
        span.end(end_time=datetime_to_nanos(end_time))

@@ -533,15 +553,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
    }
    metadata.update(trace_info.metadata)

    trace_id = string_to_trace_id128(trace_info.message_id)
    span_id = RandomIdGenerator().generate_span_id()
    context = SpanContext(
        trace_id=trace_id,
        span_id=span_id,
        is_remote=False,
        trace_flags=TraceFlags(TraceFlags.SAMPLED),
        trace_state=TraceState(),
    )
    dify_trace_id = trace_info.trace_id or trace_info.message_id
    self.ensure_root_span(dify_trace_id)
    root_span_context = self.propagator.extract(carrier=self.carrier)

    span = self.tracer.start_span(
        name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
@@ -554,19 +568,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            "end_time": end_time.isoformat() if end_time else "",
        },
        start_time=datetime_to_nanos(start_time),
        context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
        context=root_span_context,
    )

    try:
        if trace_info.message_data.error:
            span.add_event(
                "exception",
                attributes={
                    "exception.message": trace_info.message_data.error,
                    "exception.type": "Error",
                    "exception.stacktrace": trace_info.message_data.error,
                },
            )
            set_span_status(span, trace_info.message_data.error)
        else:
            set_span_status(span)
    finally:
        span.end(end_time=datetime_to_nanos(end_time))

@@ -580,20 +589,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        "tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
    }

    trace_id = string_to_trace_id128(trace_info.message_id)
    tool_span_id = RandomIdGenerator().generate_span_id()
    logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)

    # Create span context with the same trace_id as the parent
    # todo: Create with the appropriate parent span context, so that the tool span is
    # a child of the appropriate span (e.g. message span)
    span_context = SpanContext(
        trace_id=trace_id,
        span_id=tool_span_id,
        is_remote=False,
        trace_flags=TraceFlags(TraceFlags.SAMPLED),
        trace_state=TraceState(),
    )
    dify_trace_id = trace_info.trace_id or trace_info.message_id
    self.ensure_root_span(dify_trace_id)
    root_span_context = self.propagator.extract(carrier=self.carrier)

    tool_params_str = (
        json.dumps(trace_info.tool_parameters, ensure_ascii=False)
@@ -612,19 +610,14 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            SpanAttributes.TOOL_PARAMETERS: tool_params_str,
        },
        start_time=datetime_to_nanos(trace_info.start_time),
        context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
        context=root_span_context,
    )

    try:
        if trace_info.error:
            span.add_event(
                "exception",
                attributes={
                    "exception.message": trace_info.error,
                    "exception.type": "Error",
                    "exception.stacktrace": trace_info.error,
                },
            )
            set_span_status(span, trace_info.error)
        else:
            set_span_status(span)
    finally:
        span.end(end_time=datetime_to_nanos(trace_info.end_time))

@@ -641,15 +634,9 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
    }
    metadata.update(trace_info.metadata)

    trace_id = string_to_trace_id128(trace_info.message_id)
    span_id = RandomIdGenerator().generate_span_id()
    context = SpanContext(
        trace_id=trace_id,
        span_id=span_id,
        is_remote=False,
        trace_flags=TraceFlags(TraceFlags.SAMPLED),
        trace_state=TraceState(),
    )
    dify_trace_id = trace_info.trace_id or trace_info.message_id or trace_info.conversation_id
    self.ensure_root_span(dify_trace_id)
    root_span_context = self.propagator.extract(carrier=self.carrier)

    span = self.tracer.start_span(
        name=TraceTaskName.GENERATE_NAME_TRACE.value,
@@ -663,22 +650,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
            "end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
        },
        start_time=datetime_to_nanos(trace_info.start_time),
        context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
        context=root_span_context,
    )

    try:
        if trace_info.message_data.error:
            span.add_event(
                "exception",
                attributes={
                    "exception.message": trace_info.message_data.error,
                    "exception.type": "Error",
                    "exception.stacktrace": trace_info.message_data.error,
                },
            )
            set_span_status(span, trace_info.message_data.error)
        else:
            set_span_status(span)
    finally:
        span.end(end_time=datetime_to_nanos(trace_info.end_time))

def ensure_root_span(self, dify_trace_id: str | None):
    """Ensure a unique root span exists for the given Dify trace ID."""
    if str(dify_trace_id) not in self.dify_trace_ids:
        self.carrier: dict[str, str] = {}

        root_span = self.tracer.start_span(name="Dify")
        root_span.set_attribute(SpanAttributes.OPENINFERENCE_SPAN_KIND, OpenInferenceSpanKindValues.CHAIN.value)
        root_span.set_attribute("dify_project_name", str(self.project))
        root_span.set_attribute("dify_trace_id", str(dify_trace_id))

        with use_span(root_span, end_on_exit=False):
            self.propagator.inject(carrier=self.carrier)

        set_span_status(root_span)
        root_span.end()
        self.dify_trace_ids.add(str(dify_trace_id))

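`ensure_root_span` leans on the propagator inject/extract round trip: the root span's context is serialized into a carrier dict while the span is active, and later extractions of that carrier parent sibling spans under the same root. A small sketch of the mechanism, assuming the W3C trace-context propagator:

```python
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

tracer = TracerProvider().get_tracer("sketch")
propagator = TraceContextTextMapPropagator()
carrier: dict[str, str] = {}

root = tracer.start_span("Dify")
with use_span(root, end_on_exit=False):
    propagator.inject(carrier=carrier)  # carrier now holds a W3C traceparent entry
root.end()

# Any later span started from the extracted context joins the root's trace.
ctx = propagator.extract(carrier=carrier)
child = tracer.start_span("workflow", context=ctx)
child.end()
```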
def api_check(self):
    try:
        with self.tracer.start_span("api_check") as span:
@@ -698,26 +697,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
        logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True)
        raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")

def _get_workflow_nodes(self, workflow_run_id: str):
    """Helper method to get workflow nodes"""
    workflow_nodes = db.session.scalars(
        select(
            WorkflowNodeExecutionModel.id,
            WorkflowNodeExecutionModel.tenant_id,
            WorkflowNodeExecutionModel.app_id,
            WorkflowNodeExecutionModel.title,
            WorkflowNodeExecutionModel.node_type,
            WorkflowNodeExecutionModel.status,
            WorkflowNodeExecutionModel.inputs,
            WorkflowNodeExecutionModel.outputs,
            WorkflowNodeExecutionModel.created_at,
            WorkflowNodeExecutionModel.elapsed_time,
            WorkflowNodeExecutionModel.process_data,
            WorkflowNodeExecutionModel.execution_metadata,
        ).where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
    ).all()
    return workflow_nodes

def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
    """Helper method to construct LLM attributes with passed prompts."""
    attributes = {}

@@ -62,6 +62,9 @@ class MessageTraceInfo(BaseTraceInfo):
file_list: Union[str, dict[str, Any], list] | None = None
message_file_data: Any | None = None
conversation_mode: str
gen_ai_server_time_to_first_token: float | None = None
llm_streaming_time_to_generate: float | None = None
is_streaming_request: bool = False


class ModerationTraceInfo(BaseTraceInfo):

@@ -12,9 +12,9 @@ from uuid import UUID, uuid4
from cachetools import LRUCache
from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker

from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
    OPS_FILE_PATH,
    TracingProviderEnum,
@@ -34,7 +34,8 @@ from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog, WorkflowRun
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from tasks.ops_trace_task import process_trace_tasks

if TYPE_CHECKING:
@@ -140,6 +141,8 @@ provider_config_map = OpsTraceProviderConfigMap()

class OpsTraceManager:
    ops_trace_instances_cache: LRUCache = LRUCache(maxsize=128)
    decrypted_configs_cache: LRUCache = LRUCache(maxsize=128)
    _decryption_cache_lock = threading.RLock()

    @classmethod
    def encrypt_tracing_config(
@@ -160,7 +163,7 @@ class OpsTraceManager:
            provider_config_map[tracing_provider]["other_keys"],
        )

        new_config = {}
        new_config: dict[str, Any] = {}
        # Encrypt necessary keys
        for key in secret_keys:
            if key in tracing_config:
@@ -190,20 +193,41 @@ class OpsTraceManager:
        :param tracing_config: tracing config
        :return:
        """
        config_class, secret_keys, other_keys = (
            provider_config_map[tracing_provider]["config_class"],
            provider_config_map[tracing_provider]["secret_keys"],
            provider_config_map[tracing_provider]["other_keys"],
        config_json = json.dumps(tracing_config, sort_keys=True)
        decrypted_config_key = (
            tenant_id,
            tracing_provider,
            config_json,
        )
        new_config = {}
        for key in secret_keys:
            if key in tracing_config:
                new_config[key] = decrypt_token(tenant_id, tracing_config[key])

        for key in other_keys:
            new_config[key] = tracing_config.get(key, "")
        # First check without lock for performance
        cached_config = cls.decrypted_configs_cache.get(decrypted_config_key)
        if cached_config is not None:
            return dict(cached_config)

        return config_class(**new_config).model_dump()
        with cls._decryption_cache_lock:
            # Second check (double-checked locking) to prevent race conditions
            cached_config = cls.decrypted_configs_cache.get(decrypted_config_key)
            if cached_config is not None:
                return dict(cached_config)

            config_class, secret_keys, other_keys = (
                provider_config_map[tracing_provider]["config_class"],
                provider_config_map[tracing_provider]["secret_keys"],
                provider_config_map[tracing_provider]["other_keys"],
            )
            new_config: dict[str, Any] = {}
            keys_to_decrypt = [key for key in secret_keys if key in tracing_config]
            if keys_to_decrypt:
                decrypted_values = batch_decrypt_token(tenant_id, [tracing_config[key] for key in keys_to_decrypt])
                new_config.update(zip(keys_to_decrypt, decrypted_values))

            for key in other_keys:
                new_config[key] = tracing_config.get(key, "")

            decrypted_config = config_class(**new_config).model_dump()
            cls.decrypted_configs_cache[decrypted_config_key] = decrypted_config
            return dict(decrypted_config)

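The rewritten `decrypt_tracing_config` above is a textbook double-checked locking cache: a lock-free read on the hot path, then a second read under the lock so only one thread ever pays for decryption. The core of the pattern, reduced to a sketch with illustrative names:

```python
import threading

from cachetools import LRUCache

_cache: LRUCache = LRUCache(maxsize=128)
_lock = threading.RLock()

def get_or_compute(key, compute):
    value = _cache.get(key)  # first check, no lock, fast path
    if value is not None:
        return value
    with _lock:
        value = _cache.get(key)  # second check under the lock
        if value is not None:
            return value  # another thread filled it while we waited
        value = compute()  # expensive work runs at most once per key
        _cache[key] = value
        return value
```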
@classmethod
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
@@ -218,7 +242,7 @@ class OpsTraceManager:
        provider_config_map[tracing_provider]["secret_keys"],
        provider_config_map[tracing_provider]["other_keys"],
    )
    new_config = {}
    new_config: dict[str, Any] = {}
    for key in secret_keys:
        if key in decrypt_tracing_config:
            new_config[key] = obfuscated_token(decrypt_tracing_config[key])
@@ -419,6 +443,18 @@ class OpsTraceManager:


class TraceTask:
    _workflow_run_repo = None
    _repo_lock = threading.Lock()

    @classmethod
    def _get_workflow_run_repo(cls):
        if cls._workflow_run_repo is None:
            with cls._repo_lock:
                if cls._workflow_run_repo is None:
                    session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
                    cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
        return cls._workflow_run_repo

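`_get_workflow_run_repo` applies the same double-checked pattern to a lazily built repository; `expire_on_commit=False` keeps loaded ORM objects readable after the session commits, which matters because trace data is consumed outside the session scope. A self-contained sketch under those assumptions (the engine and factory are stand-ins):

```python
import threading

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

class RepoHolder:
    _repo = None
    _lock = threading.Lock()

    @classmethod
    def get(cls, make_repo):
        if cls._repo is None:  # fast path without the lock
            with cls._lock:
                if cls._repo is None:  # re-check once the lock is held
                    engine = create_engine("sqlite://")  # stand-in for db.engine
                    maker = sessionmaker(bind=engine, expire_on_commit=False)
                    cls._repo = make_repo(maker)
        return cls._repo
```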
    def __init__(
        self,
        trace_type: Any,
@@ -486,27 +522,27 @@ class TraceTask:
        if not workflow_run_id:
            return {}

        workflow_run_repo = self._get_workflow_run_repo()
        workflow_run = workflow_run_repo.get_workflow_run_by_id_without_tenant(run_id=workflow_run_id)
        if not workflow_run:
            raise ValueError("Workflow run not found")

        workflow_id = workflow_run.workflow_id
        tenant_id = workflow_run.tenant_id
        workflow_run_id = workflow_run.id
        workflow_run_elapsed_time = workflow_run.elapsed_time
        workflow_run_status = workflow_run.status
        workflow_run_inputs = workflow_run.inputs_dict
        workflow_run_outputs = workflow_run.outputs_dict
        workflow_run_version = workflow_run.version
        error = workflow_run.error or ""

        total_tokens = workflow_run.total_tokens

        file_list = workflow_run_inputs.get("sys.file") or []
        query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""

        with Session(db.engine) as session:
            workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
            workflow_run = session.scalars(workflow_run_stmt).first()
            if not workflow_run:
                raise ValueError("Workflow run not found")

            workflow_id = workflow_run.workflow_id
            tenant_id = workflow_run.tenant_id
            workflow_run_id = workflow_run.id
            workflow_run_elapsed_time = workflow_run.elapsed_time
            workflow_run_status = workflow_run.status
            workflow_run_inputs = workflow_run.inputs_dict
            workflow_run_outputs = workflow_run.outputs_dict
            workflow_run_version = workflow_run.version
            error = workflow_run.error or ""

            total_tokens = workflow_run.total_tokens

            file_list = workflow_run_inputs.get("sys.file") or []
            query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""

            # get workflow_app_log_id
            workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
                WorkflowAppLog.tenant_id == tenant_id,
@@ -523,43 +559,43 @@ class TraceTask:
            )
            message_id = session.scalar(message_data_stmt)

            metadata = {
                "workflow_id": workflow_id,
                "conversation_id": conversation_id,
                "workflow_run_id": workflow_run_id,
                "tenant_id": tenant_id,
                "elapsed_time": workflow_run_elapsed_time,
                "status": workflow_run_status,
                "version": workflow_run_version,
                "total_tokens": total_tokens,
                "file_list": file_list,
                "triggered_from": workflow_run.triggered_from,
                "user_id": user_id,
                "app_id": workflow_run.app_id,
            }
            metadata = {
                "workflow_id": workflow_id,
                "conversation_id": conversation_id,
                "workflow_run_id": workflow_run_id,
                "tenant_id": tenant_id,
                "elapsed_time": workflow_run_elapsed_time,
                "status": workflow_run_status,
                "version": workflow_run_version,
                "total_tokens": total_tokens,
                "file_list": file_list,
                "triggered_from": workflow_run.triggered_from,
                "user_id": user_id,
                "app_id": workflow_run.app_id,
            }

            workflow_trace_info = WorkflowTraceInfo(
                trace_id=self.trace_id,
                workflow_data=workflow_run.to_dict(),
                conversation_id=conversation_id,
                workflow_id=workflow_id,
                tenant_id=tenant_id,
                workflow_run_id=workflow_run_id,
                workflow_run_elapsed_time=workflow_run_elapsed_time,
                workflow_run_status=workflow_run_status,
                workflow_run_inputs=workflow_run_inputs,
                workflow_run_outputs=workflow_run_outputs,
                workflow_run_version=workflow_run_version,
                error=error,
                total_tokens=total_tokens,
                file_list=file_list,
                query=query,
                metadata=metadata,
                workflow_app_log_id=workflow_app_log_id,
                message_id=message_id,
                start_time=workflow_run.created_at,
                end_time=workflow_run.finished_at,
            )
            workflow_trace_info = WorkflowTraceInfo(
                trace_id=self.trace_id,
                workflow_data=workflow_run.to_dict(),
                conversation_id=conversation_id,
                workflow_id=workflow_id,
                tenant_id=tenant_id,
                workflow_run_id=workflow_run_id,
                workflow_run_elapsed_time=workflow_run_elapsed_time,
                workflow_run_status=workflow_run_status,
                workflow_run_inputs=workflow_run_inputs,
                workflow_run_outputs=workflow_run_outputs,
                workflow_run_version=workflow_run_version,
                error=error,
                total_tokens=total_tokens,
                file_list=file_list,
                query=query,
                metadata=metadata,
                workflow_app_log_id=workflow_app_log_id,
                message_id=message_id,
                start_time=workflow_run.created_at,
                end_time=workflow_run.finished_at,
            )
        return workflow_trace_info

    def message_trace(self, message_id: str | None):
@@ -583,6 +619,8 @@ class TraceTask:
        file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
        file_list.append(file_url)

        streaming_metrics = self._extract_streaming_metrics(message_data)

        metadata = {
            "conversation_id": message_data.conversation_id,
            "ls_provider": message_data.model_provider,
@@ -615,6 +653,9 @@ class TraceTask:
            metadata=metadata,
            message_file_data=message_file_data,
            conversation_mode=conversation_mode,
            gen_ai_server_time_to_first_token=streaming_metrics.get("gen_ai_server_time_to_first_token"),
            llm_streaming_time_to_generate=streaming_metrics.get("llm_streaming_time_to_generate"),
            is_streaming_request=streaming_metrics.get("is_streaming_request", False),
        )

        return message_trace_info
@@ -840,6 +881,24 @@ class TraceTask:

        return generate_name_trace_info

    def _extract_streaming_metrics(self, message_data) -> dict:
        if not message_data.message_metadata:
            return {}

        try:
            metadata = json.loads(message_data.message_metadata)
            usage = metadata.get("usage", {})
            time_to_first_token = usage.get("time_to_first_token")
            time_to_generate = usage.get("time_to_generate")

            return {
                "gen_ai_server_time_to_first_token": time_to_first_token,
                "llm_streaming_time_to_generate": time_to_generate,
                "is_streaming_request": time_to_first_token is not None,
            }
        except (json.JSONDecodeError, AttributeError):
            return {}

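For reference, `_extract_streaming_metrics` expects `message_metadata` to be a JSON string containing a `usage` object; the values below are invented to show the mapping:

```python
import json
from types import SimpleNamespace

message_data = SimpleNamespace(
    message_metadata=json.dumps(
        {"usage": {"time_to_first_token": 0.42, "time_to_generate": 3.1}}
    )
)

usage = json.loads(message_data.message_metadata).get("usage", {})
assert usage.get("time_to_first_token") == 0.42  # -> gen_ai_server_time_to_first_token
assert usage.get("time_to_generate") == 3.1  # -> llm_streaming_time_to_generate
# is_streaming_request is True exactly when time_to_first_token is present.
```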
trace_manager_timer: threading.Timer | None = None
trace_manager_queue: queue.Queue = queue.Queue()

@@ -5,12 +5,18 @@ Tencent APM Trace Client - handles network operations, metrics, and API communic
from __future__ import annotations

import importlib
import json
import logging
import os
import socket
from typing import TYPE_CHECKING
from urllib.parse import urlparse

try:
    from importlib.metadata import version
except ImportError:
    from importlib_metadata import version  # type: ignore[import-not-found]

if TYPE_CHECKING:
    from opentelemetry.metrics import Meter
    from opentelemetry.metrics._internal.instrument import Histogram
@@ -27,12 +33,27 @@ from opentelemetry.util.types import AttributeValue

from configs import dify_config

from .entities.tencent_semconv import LLM_OPERATION_DURATION
from .entities.semconv import (
    GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
    GEN_AI_STREAMING_TIME_TO_GENERATE,
    GEN_AI_TOKEN_USAGE,
    GEN_AI_TRACE_DURATION,
    LLM_OPERATION_DURATION,
)
from .entities.tencent_trace_entity import SpanData

logger = logging.getLogger(__name__)


def _get_opentelemetry_sdk_version() -> str:
    """Get OpenTelemetry SDK version dynamically."""
    try:
        return version("opentelemetry-sdk")
    except Exception:
        logger.debug("Failed to get opentelemetry-sdk version, using default")
        return "1.27.0"  # fallback version


class TencentTraceClient:
    """Tencent APM trace client using OpenTelemetry OTLP exporter"""

@@ -57,6 +78,9 @@ class TencentTraceClient:
                ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
                ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
                ResourceAttributes.HOST_NAME: socket.gethostname(),
                ResourceAttributes.TELEMETRY_SDK_LANGUAGE: "python",
                ResourceAttributes.TELEMETRY_SDK_NAME: "opentelemetry",
                ResourceAttributes.TELEMETRY_SDK_VERSION: _get_opentelemetry_sdk_version(),
            }
        )
        # Prepare gRPC endpoint/metadata
@@ -80,18 +104,23 @@ class TencentTraceClient:
        )
        self.tracer_provider.add_span_processor(self.span_processor)

        self.tracer = self.tracer_provider.get_tracer("dify.tencent_apm")
        # use dify api version as tracer version
        self.tracer = self.tracer_provider.get_tracer("dify-sdk", dify_config.project.version)

        # Store span contexts for parent-child relationships
        self.span_contexts: dict[int, trace_api.SpanContext] = {}

        self.meter: Meter | None = None
        self.meter_provider: MeterProvider | None = None
        self.hist_llm_duration: Histogram | None = None
        self.hist_token_usage: Histogram | None = None
        self.hist_time_to_first_token: Histogram | None = None
        self.hist_time_to_generate: Histogram | None = None
        self.hist_trace_duration: Histogram | None = None
        self.metric_reader: MetricReader | None = None

        # Metrics exporter and instruments
        try:
            from opentelemetry import metrics
            from opentelemetry.sdk.metrics import Histogram, MeterProvider
            from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader

@@ -99,7 +128,7 @@ class TencentTraceClient:
            use_http_protobuf = protocol in {"http/protobuf", "http-protobuf"}
            use_http_json = protocol in {"http/json", "http-json"}

            # Set preferred temporality for histograms to DELTA
            # Tencent APM works best with delta aggregation temporality
            preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA}

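A sketch of what the DELTA-temporality wiring above amounts to, using the gRPC OTLP metric exporter; the endpoint is a placeholder, and the real exporter choice depends on the protocol flags above:

```python
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter
from opentelemetry.sdk.metrics import Histogram
from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader

preferred_temporality = {Histogram: AggregationTemporality.DELTA}
exporter = OTLPMetricExporter(
    endpoint="localhost:4317",  # placeholder endpoint
    preferred_temporality=preferred_temporality,  # histograms export deltas, not cumulative sums
)
reader = PeriodicExportingMetricReader(exporter, export_interval_millis=60_000)
```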
            def _create_metric_exporter(exporter_cls, **kwargs):
@@ -174,23 +203,66 @@ class TencentTraceClient:
            )

            if metric_reader is not None:
                # Use instance-level MeterProvider instead of global to support config changes
                # without worker restart. Each TencentTraceClient manages its own MeterProvider.
                provider = MeterProvider(resource=self.resource, metric_readers=[metric_reader])
                metrics.set_meter_provider(provider)
                self.meter = metrics.get_meter("dify-sdk", dify_config.project.version)
                self.meter_provider = provider
                self.meter = provider.get_meter("dify-sdk", dify_config.project.version)

                # LLM operation duration histogram
                self.hist_llm_duration = self.meter.create_histogram(
                    name=LLM_OPERATION_DURATION,
                    unit="s",
                    description="LLM operation duration (seconds)",
                )

                # Token usage histogram with exponential buckets
                self.hist_token_usage = self.meter.create_histogram(
                    name=GEN_AI_TOKEN_USAGE,
                    unit="token",
                    description="Number of tokens used in prompt and completions",
                )

                # Time to first token histogram
                self.hist_time_to_first_token = self.meter.create_histogram(
                    name=GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
                    unit="s",
                    description="Time to first token for streaming LLM responses (seconds)",
                )

                # Time to generate histogram
                self.hist_time_to_generate = self.meter.create_histogram(
                    name=GEN_AI_STREAMING_TIME_TO_GENERATE,
                    unit="s",
                    description="Total time to generate streaming LLM responses (seconds)",
                )

                # Trace duration histogram
                self.hist_trace_duration = self.meter.create_histogram(
                    name=GEN_AI_TRACE_DURATION,
                    unit="s",
                    description="End-to-end GenAI trace duration (seconds)",
                )

                self.metric_reader = metric_reader
            else:
                self.meter = None
                self.meter_provider = None
                self.hist_llm_duration = None
                self.hist_token_usage = None
                self.hist_time_to_first_token = None
                self.hist_time_to_generate = None
                self.hist_trace_duration = None
                self.metric_reader = None
        except Exception:
            logger.exception("[Tencent APM] Metrics initialization failed; metrics disabled")
            self.meter = None
            self.meter_provider = None
            self.hist_llm_duration = None
            self.hist_token_usage = None
            self.hist_time_to_first_token = None
            self.hist_time_to_generate = None
            self.hist_trace_duration = None
            self.metric_reader = None

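The key change above is holding the `MeterProvider` on the client instead of installing it globally, so a config change can shut it down and rebuild it without restarting the worker. A minimal sketch of that per-instance lifecycle (the in-memory reader stands in for the OTLP reader):

```python
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader
from opentelemetry.sdk.resources import Resource

reader = InMemoryMetricReader()  # stand-in for the OTLP metric reader built above
provider = MeterProvider(resource=Resource.create({}), metric_readers=[reader])
meter = provider.get_meter("dify-sdk", "0.0.0")  # version string is illustrative

hist = meter.create_histogram(
    name="gen_ai.client.operation.duration",
    unit="s",
    description="LLM operation duration (seconds)",
)
hist.record(1.23, {"gen_ai.operation.name": "chat"})

provider.shutdown()  # per-instance lifecycle; no global meter provider to unwind
```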
    def add_span(self, span_data: SpanData) -> None:
@@ -212,10 +284,158 @@ class TencentTraceClient:
            if attributes:
                for k, v in attributes.items():
                    attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v  # type: ignore[assignment]

            logger.info(
                "[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
                LLM_OPERATION_DURATION,
                latency_seconds,
                json.dumps(attrs, ensure_ascii=False),
            )

            self.hist_llm_duration.record(latency_seconds, attrs)  # type: ignore[attr-defined]
        except Exception:
            logger.debug("[Tencent APM] Failed to record LLM duration", exc_info=True)

    def record_token_usage(
        self,
        token_count: int,
        token_type: str,
        operation_name: str,
        request_model: str,
        response_model: str,
        server_address: str,
        provider: str,
    ) -> None:
        """Record token usage histogram.

        Args:
            token_count: Number of tokens used
            token_type: "input" or "output"
            operation_name: Operation name (e.g., "chat")
            request_model: Model used in request
            response_model: Model used in response
            server_address: Server address
            provider: Model provider name
        """
        try:
            if not hasattr(self, "hist_token_usage") or self.hist_token_usage is None:
                return

            attributes = {
                "gen_ai.operation.name": operation_name,
                "gen_ai.request.model": request_model,
                "gen_ai.response.model": response_model,
                "gen_ai.system": provider,
                "gen_ai.token.type": token_type,
                "server.address": server_address,
            }

            logger.info(
                "[Tencent Metrics] Metric: %s | Value: %d | Attributes: %s",
                GEN_AI_TOKEN_USAGE,
                token_count,
                json.dumps(attributes, ensure_ascii=False),
            )

            self.hist_token_usage.record(token_count, attributes)  # type: ignore[attr-defined]
        except Exception:
            logger.debug("[Tencent APM] Failed to record token usage", exc_info=True)

    def record_time_to_first_token(
        self, ttft_seconds: float, provider: str, model: str, operation_name: str = "chat"
    ) -> None:
        """Record time to first token histogram.

        Args:
            ttft_seconds: Time to first token in seconds
            provider: Model provider name
            model: Model name
            operation_name: Operation name (default: "chat")
        """
        try:
            if not hasattr(self, "hist_time_to_first_token") or self.hist_time_to_first_token is None:
                return

            attributes = {
                "gen_ai.operation.name": operation_name,
                "gen_ai.system": provider,
                "gen_ai.request.model": model,
                "gen_ai.response.model": model,
                "stream": "true",
            }

            logger.info(
                "[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
                GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
                ttft_seconds,
                json.dumps(attributes, ensure_ascii=False),
            )

            self.hist_time_to_first_token.record(ttft_seconds, attributes)  # type: ignore[attr-defined]
        except Exception:
            logger.debug("[Tencent APM] Failed to record time to first token", exc_info=True)

    def record_time_to_generate(
        self, ttg_seconds: float, provider: str, model: str, operation_name: str = "chat"
    ) -> None:
        """Record time to generate histogram.

        Args:
            ttg_seconds: Time to generate in seconds
            provider: Model provider name
            model: Model name
            operation_name: Operation name (default: "chat")
        """
        try:
            if not hasattr(self, "hist_time_to_generate") or self.hist_time_to_generate is None:
                return

            attributes = {
                "gen_ai.operation.name": operation_name,
                "gen_ai.system": provider,
                "gen_ai.request.model": model,
                "gen_ai.response.model": model,
                "stream": "true",
            }

            logger.info(
                "[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
                GEN_AI_STREAMING_TIME_TO_GENERATE,
                ttg_seconds,
                json.dumps(attributes, ensure_ascii=False),
            )

            self.hist_time_to_generate.record(ttg_seconds, attributes)  # type: ignore[attr-defined]
        except Exception:
            logger.debug("[Tencent APM] Failed to record time to generate", exc_info=True)

    def record_trace_duration(self, duration_seconds: float, attributes: dict[str, str] | None = None) -> None:
        """Record end-to-end trace duration histogram in seconds.

        Args:
            duration_seconds: Trace duration in seconds
            attributes: Optional attributes (e.g., conversation_mode, app_id)
        """
        try:
            if not hasattr(self, "hist_trace_duration") or self.hist_trace_duration is None:
                return

            attrs: dict[str, str] = {}
            if attributes:
                for k, v in attributes.items():
                    attrs[k] = str(v) if not isinstance(v, (str, int, float, bool)) else v  # type: ignore[assignment]

            logger.info(
                "[Tencent Metrics] Metric: %s | Value: %.4f | Attributes: %s",
                GEN_AI_TRACE_DURATION,
                duration_seconds,
                json.dumps(attrs, ensure_ascii=False),
            )

            self.hist_trace_duration.record(duration_seconds, attrs)  # type: ignore[attr-defined]
        except Exception:
            logger.debug("[Tencent APM] Failed to record trace duration", exc_info=True)

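Hypothetical call sites for the recording helpers above; the client argument is assumed to be a configured `TencentTraceClient`, and all values are invented for illustration:

```python
def emit_llm_metrics(client) -> None:
    """Record one request's metrics against an already-configured client."""
    client.record_token_usage(
        token_count=128,
        token_type="input",  # "input" or "output"
        operation_name="chat",
        request_model="gpt-4o-mini",
        response_model="gpt-4o-mini",
        server_address="api.example.com",
        provider="openai",
    )
    client.record_time_to_first_token(0.35, provider="openai", model="gpt-4o-mini")
    client.record_time_to_generate(2.8, provider="openai", model="gpt-4o-mini")
    client.record_trace_duration(3.4, {"app_id": "demo-app"})
```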
    def _create_and_export_span(self, span_data: SpanData) -> None:
        """Create span using OpenTelemetry Tracer API"""
        try:
@@ -296,11 +516,19 @@ class TencentTraceClient:

        if self.tracer_provider:
            self.tracer_provider.shutdown()

        # Shutdown instance-level meter provider
        if self.meter_provider is not None:
            try:
                self.meter_provider.shutdown()  # type: ignore[attr-defined]
            except Exception:
                logger.debug("[Tencent APM] Error shutting down meter provider", exc_info=True)

        if self.metric_reader is not None:
            try:
                self.metric_reader.shutdown()  # type: ignore[attr-defined]
            except Exception:
                pass
                logger.debug("[Tencent APM] Error shutting down metric reader", exc_info=True)

        except Exception:
            logger.exception("[Tencent APM] Error during client shutdown")

@@ -47,6 +47,9 @@ GEN_AI_COMPLETION = "gen_ai.completion"

GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"

# Streaming Span Attributes
GEN_AI_IS_STREAMING_REQUEST = "llm.is_streaming"  # Same as OpenLLMetry semconv

# Tool
TOOL_NAME = "tool.name"

@@ -62,6 +65,19 @@ INSTRUMENTATION_LANGUAGE = "python"

# Metrics
LLM_OPERATION_DURATION = "gen_ai.client.operation.duration"
GEN_AI_TOKEN_USAGE = "gen_ai.client.token.usage"
GEN_AI_SERVER_TIME_TO_FIRST_TOKEN = "gen_ai.server.time_to_first_token"
GEN_AI_STREAMING_TIME_TO_GENERATE = "gen_ai.streaming.time_to_generate"
# The LLM trace duration which is exclusive to tencent apm
GEN_AI_TRACE_DURATION = "gen_ai.trace.duration"

# Token Usage Attributes
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
GEN_AI_REQUEST_MODEL = "gen_ai.request.model"
GEN_AI_RESPONSE_MODEL = "gen_ai.response.model"
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_TOKEN_TYPE = "gen_ai.token.type"
SERVER_ADDRESS = "server.address"


class GenAISpanKind(Enum):
@@ -14,10 +14,11 @@ from core.ops.entities.trace_entity import (
    ToolTraceInfo,
    WorkflowTraceInfo,
)
from core.ops.tencent_trace.entities.tencent_semconv import (
from core.ops.tencent_trace.entities.semconv import (
    GEN_AI_COMPLETION,
    GEN_AI_FRAMEWORK,
    GEN_AI_IS_ENTRY,
    GEN_AI_IS_STREAMING_REQUEST,
    GEN_AI_MODEL_NAME,
    GEN_AI_PROMPT,
    GEN_AI_PROVIDER,
@@ -156,6 +157,25 @@ class TencentSpanBuilder:
        outputs = node_execution.outputs or {}
        usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})

        attributes = {
            GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
            GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
            GEN_AI_FRAMEWORK: "dify",
            GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
            GEN_AI_PROVIDER: process_data.get("model_provider", ""),
            GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
            GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
            GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
            GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
            GEN_AI_COMPLETION: str(outputs.get("text", "")),
            GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
            INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
            OUTPUT_VALUE: str(outputs.get("text", "")),
        }

        if usage_data.get("time_to_first_token") is not None:
            attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"

        return SpanData(
            trace_id=trace_id,
            parent_span_id=workflow_span_id,
@@ -163,21 +183,7 @@ class TencentSpanBuilder:
            name="GENERATION",
            start_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.created_at),
            end_time=TencentSpanBuilder._get_time_nanoseconds(node_execution.finished_at),
            attributes={
                GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
                GEN_AI_SPAN_KIND: GenAISpanKind.GENERATION.value,
                GEN_AI_FRAMEWORK: "dify",
                GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
                GEN_AI_PROVIDER: process_data.get("model_provider", ""),
                GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
                GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
                GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
                GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
                GEN_AI_COMPLETION: str(outputs.get("text", "")),
                GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
                INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
                OUTPUT_VALUE: str(outputs.get("text", "")),
            },
            attributes=attributes,
            status=TencentSpanBuilder._get_workflow_node_status(node_execution),
        )

@@ -191,6 +197,19 @@ class TencentSpanBuilder:
        if trace_info.error:
            status = Status(StatusCode.ERROR, trace_info.error)

        attributes = {
            GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
            GEN_AI_USER_ID: str(user_id),
            GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
            GEN_AI_FRAMEWORK: "dify",
            GEN_AI_IS_ENTRY: "true",
            INPUT_VALUE: str(trace_info.inputs or ""),
            OUTPUT_VALUE: str(trace_info.outputs or ""),
        }

        if trace_info.is_streaming_request:
            attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"

        return SpanData(
            trace_id=trace_id,
            parent_span_id=None,
@@ -198,15 +217,7 @@ class TencentSpanBuilder:
            name="message",
            start_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.start_time),
            end_time=TencentSpanBuilder._get_time_nanoseconds(trace_info.end_time),
            attributes={
                GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
                GEN_AI_USER_ID: str(user_id),
                GEN_AI_SPAN_KIND: GenAISpanKind.WORKFLOW.value,
                GEN_AI_FRAMEWORK: "dify",
                GEN_AI_IS_ENTRY: "true",
                INPUT_VALUE: str(trace_info.inputs or ""),
                OUTPUT_VALUE: str(trace_info.outputs or ""),
            },
            attributes=attributes,
            status=status,
            links=links,
        )

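The span-builder refactor above builds the attribute dict first so conditional keys (like the streaming flag) can be added before `SpanData` is constructed. A hedged sketch of that shape; `SpanDataStub` is a local stand-in for Dify's `SpanData`, its fields follow the constructor calls shown in this diff, and all values are invented:

```python
import json
from dataclasses import dataclass, field

@dataclass
class SpanDataStub:  # stand-in; field names follow the SpanData(...) calls above
    trace_id: int
    parent_span_id: int | None
    name: str
    start_time: int
    end_time: int
    attributes: dict[str, str] = field(default_factory=dict)

GEN_AI_IS_STREAMING_REQUEST = "llm.is_streaming"  # value from the semconv hunk above

attributes = {
    "gen_ai.framework": "dify",
    "gen_ai.prompt": json.dumps([{"role": "user", "text": "hi"}], ensure_ascii=False),
}
usage_data = {"time_to_first_token": 0.4}  # invented usage payload
if usage_data.get("time_to_first_token") is not None:
    attributes[GEN_AI_IS_STREAMING_REQUEST] = "true"

span = SpanDataStub(
    trace_id=1,
    parent_span_id=None,
    name="GENERATION",
    start_time=0,
    end_time=10,
    attributes=attributes,
)
```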
Some files were not shown because too many files have changed in this diff.