Compare commits

...

266 Commits

Author SHA1 Message Date
cc63c8499f bump version to 0.3.26 (#1307)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 16:11:24 +08:00
f191b8b8d1 milvus docker compose env (#1306)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 16:05:37 +08:00
5003db987d milvus secure check fix (#1305)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 13:11:06 +08:00
07aab5e868 Feat/add milvus vector db (#1302)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-10 21:56:24 +08:00
875dfbbf0e fix: openllm completion start with prompt, remove it (#1303) 2023-10-10 04:44:19 -05:00
9e7efa45d4 document segmentApi Add get&update&delete operate (#1285)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-10-10 13:27:06 +08:00
8bf892b306 feat: bump version to 0.3.25 (#1300) 2023-10-10 13:03:49 +08:00
8480b0197b fix: prompt for baichuan text generation models (#1299) 2023-10-10 13:01:18 +08:00
df07fb5951 feat: provider add baichuan (#1298) 2023-10-09 23:10:43 -05:00
4ab4bcc074 feat: support openllm embedding (#1293) 2023-10-09 23:09:35 -05:00
1d4f019de4 feat: add baichuan llm support (#1294)
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
2023-10-09 23:09:26 -05:00
677aacc8e3 feat: upgrade xinference client to 0.5.2 (#1292) 2023-10-09 08:12:58 -05:00
fda937175d feat: qdrant support in docker compose (#1286) 2023-10-08 12:04:04 -05:00
024250803a feat: move login_required wrapper outside (#1281) 2023-10-08 05:21:32 -05:00
b711ce33b7 Application share qrcode (#1277)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
2023-10-08 09:34:49 +08:00
52bec63275 chore(web): strong type (#1259) 2023-10-07 04:42:16 -05:00
657fa80f4d fix devcontainer issue (#1273) 2023-10-07 10:34:25 +08:00
373e90ee6d fix: detached model in completion thread (#1269) 2023-10-02 22:27:25 +08:00
41d4c5b424 fix: count down thread in completion db not commit (#1267) 2023-10-02 10:19:26 +08:00
86a9dea428 fix: db not commit when streaming output (#1266) 2023-10-01 16:41:52 +08:00
8606d80c66 fix: request timeout when openai completion (#1265) 2023-10-01 16:00:23 +08:00
5bffa1d918 feat: bump version to 0.3.24 (#1262) 2023-09-28 18:32:06 +08:00
c9b0fe47bf Fix/notion sync (#1258) 2023-09-28 14:39:13 +08:00
bcd744b6b7 fix: doc (#1256) 2023-09-28 11:26:04 +08:00
5e511e01bf Fix/dataset api key delete (#1255)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-28 10:41:41 +08:00
52291c645e fix: dataset footer styles (#1254) 2023-09-28 10:06:52 +08:00
a31466d34e fix: db session not commit before long llm call running (#1251) 2023-09-27 21:40:26 +08:00
d38eac959b fix: wenxin model name invalid when llm call (#1248) 2023-09-27 16:29:13 +08:00
9dbb8acd4b Feat/dataset support api service (#1240)
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
2023-09-27 16:06:49 +08:00
46154c6705 Feat/dataset service api (#1245)
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-09-27 16:06:32 +08:00
54ff03c35d fix: dataset query error. (#1244) 2023-09-27 15:24:54 +08:00
18c710c906 feat: support binding context var (#1227)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-09-27 14:53:22 +08:00
59236b789f Fix: dataset list refresh (#1216) 2023-09-27 10:31:46 +08:00
fd3d43cae1 Fix: debounce of dataset creation (#1237) 2023-09-27 10:31:27 +08:00
8eae643911 Fix App logs page modal show different model icon. (#1224) 2023-09-27 08:54:52 +08:00
fd9413874a fix: FATAL: role "root" does not exist. (#1233) 2023-09-26 10:20:00 +08:00
227f9fb77d Feat/api jwt (#1212) 2023-09-25 12:49:16 +08:00
c40ee7e629 feat: batch run support retry errors and decrease rate limit times (#1215) 2023-09-25 10:20:50 +08:00
841e967d48 Fix: add loading for dataset creation (#1214) 2023-09-24 01:35:20 -05:00
9df0dcedae fix: dataset eslint error (#1221) 2023-09-22 22:38:33 +08:00
724e053732 Fix/qdrant data issue (#1203)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-22 14:21:26 +08:00
e409895c02 Feat/huggingface embedding support (#1211)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-09-22 13:59:02 +08:00
32d9b6181c fix: transaction not commit during long LLM calls (#1213) 2023-09-22 12:43:06 +08:00
2b018fade2 fix: transaction hangs due to message commit block during long LLM calls (#1206) 2023-09-21 11:22:10 +08:00
e65f9cb17a Complete type defined. (#1200) 2023-09-19 23:27:06 -05:00
1367f34398 fix: provider spark free quota text (#1201) 2023-09-20 11:46:25 +08:00
e47f6b879a add help wanted issue template (#1199) 2023-09-19 20:02:41 -05:00
5809edd74b feat: bump version to 0.3.23 (#1198) 2023-09-20 00:14:36 +08:00
05bfa11915 build: update devDependencies (#1125) 2023-09-19 13:31:48 +08:00
435f804c6f fix: gpt-3.5-turbo-instruct context size to 8192 (#1196) 2023-09-19 02:10:22 +08:00
ae3f1ac0a9 feat: support gpt-3.5-turbo-instruct model (#1195) 2023-09-19 02:05:04 +08:00
269a465fc4 Feat/improve vector database logic (#1193)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-18 18:15:41 +08:00
60e0bbd713 Feat/provider add zhipuai (#1192)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-09-18 18:02:05 +08:00
827c97f0d3 feat: add zhipuai (#1188) 2023-09-18 17:32:31 +08:00
c8bd76cd66 fix: inference embedding validate (#1187) 2023-09-16 03:09:36 +08:00
ec5f585df4 1111 wrong embedding model displayed in datasets (#1186) 2023-09-15 07:54:45 -05:00
1de48f33ca feat(web): service request return generics type (#1157) 2023-09-15 07:54:20 -05:00
6b41a9593e fix: text error (#1184) 2023-09-15 14:15:28 +08:00
82267083e8 fix: model param description error (#1183) 2023-09-15 11:36:01 +08:00
c385961d33 chore: Optimization model parameter description (#1181) 2023-09-15 11:14:14 +08:00
20bab6edec Restore the application template (#1174)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-09-14 08:28:32 -05:00
67bed54f32 Mermaid front end rendering (#1166)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-09-14 14:09:23 +08:00
562a571281 fix: Improved fallback solution for avatar image loading failure (#1172) (author: leo) 2023-09-14 13:31:35 +08:00
fc68c81791 fix: correct invite url (#1173) 2023-09-14 12:07:34 +08:00
5d9070bc60 Feat/add blocking mode resource return (#1171)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-13 18:53:35 +08:00
b11fb0dfd1 fix LocalAI is missing in lang/en (#1169) 2023-09-13 10:08:33 +08:00
d1c5c5f160 add video to cn readme (#1165) 2023-09-12 08:30:12 -05:00
0b1d1440aa Update README.md (#1164) 2023-09-12 07:48:35 -05:00
0c420d64b3 chore: hover conversation show option button (#1160) 2023-09-12 16:35:13 +08:00
f9082104ed feat: add hosted moderation (#1158) 2023-09-12 10:26:12 +08:00
983834cd52 feat: spark check (#1134) 2023-09-11 17:31:03 +08:00
96d10c8b39 feat: spark free quota verify (#1152) 2023-09-11 17:30:54 +08:00
24cb992843 feat: bump version to 0.3.22 (#1153) 2023-09-11 12:04:06 +08:00
7907c0bf58 Update bug_report.yml (#1151) 2023-09-11 10:48:37 +08:00
ebf4fd9a09 Update issue template (#1150) 2023-09-11 10:45:10 +08:00
38b9901274 fix(web): complete some ts type (#1148) 2023-09-11 09:30:17 +08:00
642842d61b Feat:dataset retiever resource (#1123)
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-09-10 15:17:43 +08:00
e161c511af Feat:csv & docx support (#1139)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-10 15:17:22 +08:00
f29e82685e feat: bump version to 0.3.21 (#1145) 2023-09-10 12:34:54 +08:00
3a5ae96e7b fix: TRANSFORMERS_OFFLINE orders in Dockerfile (#1144) 2023-09-10 12:26:13 +08:00
b63a685386 feat: set transformers offline default true (#1143) 2023-09-10 12:20:58 +08:00
877da82b06 feat: cache huggingface gpt2 tokenizer files (#1138) 2023-09-10 12:16:21 +08:00
6637629045 fix: remove the deprecated depends_on.condition format (#1142) 2023-09-10 12:07:20 +08:00
e925b6c572 fix: log page compatible old query (#1141) 2023-09-10 11:29:25 +08:00
5412f4aba5 fix: in log page not show user query (#1140) 2023-09-10 09:30:30 +08:00
2d5ad0d208 feat: support optional query content (#1097)
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
2023-09-10 00:12:34 +08:00
1ade70aa1e feat: bump version to 0.3.20 (#1135) 2023-09-09 23:47:14 +08:00
2658c4d57b fix: answer returned null when response_mode was blocking (#1133) 2023-09-09 23:22:21 +08:00
84c76bc04a Feat/chat add origin (#1130) 2023-09-09 19:17:12 +08:00
6effcd3755 feat: optimize celery start cmd (#1129) 2023-09-09 13:48:29 +08:00
d9866489f0 feat: add health check and depend condition in docker compose (#1113) 2023-09-09 13:47:08 +08:00
c4d8bdc3db fix: hf hosted inference check (#1128) 2023-09-09 00:29:48 +08:00
681eb1cfcc fix: click inner link no jump (#1118) 2023-09-08 10:21:42 +08:00
a5d21f3b09 fix: shortening invite url (#1100)
Co-authored-by: MatriQi <matri@aifi.io>
2023-09-07 17:15:57 +08:00
7ba068c3e4 fix: self host embedding missing base url config (#1116) 2023-09-07 14:56:38 +08:00
b201eeedbd fix: optimize styles (#1112) 2023-09-07 14:24:09 +08:00
f28cb84977 fix(web): fix AppCard Menu popover open bug (#1107) 2023-09-07 09:47:31 +08:00
714872cd58 chore: enchancment frontend readme (#1110) 2023-09-07 09:43:24 +08:00
0708bd60ee fix: try to fix chunk load error (#1109) 2023-09-06 15:47:53 +08:00
23a6c85b80 chore: handle workspace apps scrollbar (#1101) 2023-09-05 15:56:21 +08:00
4a28599fbd fix: optimize feedback and app icon (#1099) 2023-09-05 09:13:59 +08:00
7c66d3c793 feat: Optimize the description for Azure deployment name (#1091) 2023-09-04 14:26:22 +08:00
cc9edfffd8 fix: markdown code lang capitalization and line number color (#1098) 2023-09-04 11:31:25 +08:00
6fa2454c9a fix: change frontend start script (#1096) 2023-09-04 11:10:32 +08:00
487e699021 fix: ui in chat openning statement (#1094) 2023-09-04 10:26:46 +08:00
a7cdb745c1 feat: support spark v2 validate (#1086) 2023-09-01 20:53:32 +08:00
73c86ee6a0 fix: prompt of title generation (#1084) 2023-09-01 14:55:58 +08:00
48eb590065 feat: optimize last_active_at update (#1083) 2023-09-01 13:58:26 +08:00
33562a9d8d feat: optimize prompt (#1080) 2023-09-01 11:46:06 +08:00
c9194ba382 chore(api): api image multistage build (#1069) 2023-09-01 11:13:22 +08:00
a199fa6388 feat: optimize high load sql query of document segment (#1078) 2023-09-01 10:52:39 +08:00
4c8608dc61 feat: optimize conversation title generation output must be a valid JSON (#1077) 2023-09-01 10:31:42 +08:00
a6b0f788e7 feat: add visual studio code debug config. (#1068)
Co-authored-by: Keruberosu <631677014@qq.com>
2023-09-01 09:15:06 +08:00
df6604a734 feat: optimize generation of conversation title (#1075) 2023-09-01 02:28:37 +08:00
1ca86cf9ce feat: bump version to 0.3.19 (#1074) 2023-08-31 21:42:58 +08:00
78e26f8b75 fix: summary no docs (#1073) 2023-08-31 20:19:26 +08:00
2191312bb9 fix: segments query missing idx hit (#1072) 2023-08-31 19:39:44 +08:00
fcc6b41ab7 feat: decrease claude model request time by set max top_k to 10 (#1071) 2023-08-31 18:23:44 +08:00
9458b8978f feat: siderbar operation support portal (#1061) 2023-08-31 17:46:51 +08:00
d75e8aeafa feat: disable anthropic retry (#1067) 2023-08-31 16:44:46 +08:00
2eba98a465 feat: optimize anthropic connection pool (#1066) 2023-08-31 16:18:59 +08:00
a7a7aab7a0 fix: csv import error (#1063) 2023-08-31 15:42:28 +08:00
86bfbb47d5 chore: doc issue (#1062) 2023-08-31 14:54:16 +08:00
d33a269548 refactor(file extractor): file extractor (#1059) 2023-08-31 14:45:31 +08:00
d3f8ea2df0 Feat/support to invite multiple users (#1011) 2023-08-31 01:18:31 +08:00
7df56ed617 fix error weaviate vector (#1058)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-30 20:34:17 +08:00
e34dcc0406 feat: code support copy (#1057) 2023-08-30 18:08:47 +08:00
a834ba8759 feat: support rename conversation (#1056) 2023-08-30 17:32:32 +08:00
c67f345d0e Fix: disable operations of dataset when embedding unavailable (#1055)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-30 17:27:19 +08:00
8b8e510bfe fix: handle AttributeError for datasets and index (#1052) 2023-08-30 11:14:16 +08:00
3db839a5cb 773 change embed title welcome to use (#1053) 2023-08-30 11:03:25 +08:00
417c19577a feat: add LocalAI local embedding model support (#1021)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-08-29 22:22:02 +08:00
b5953039de recreate qdrant vector (#1049)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-29 15:00:36 +08:00
a43e80dd9c add qdrant migration (#1046)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-29 10:37:04 +08:00
ad5f27bc5f fix openpyxl dimensions error (#1041) 2023-08-29 10:36:48 +08:00
05e0985f29 chore: match new dataset tool format (#1044) 2023-08-29 09:07:45 +08:00
7b3314c5db fix: dataset desc (#1045) 2023-08-29 09:07:27 +08:00
a55ba6e614 Fix/ignore economy dataset (#1043)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-29 03:37:45 +08:00
f9bec1edf8 chore: perfect type definition (#1003) 2023-08-28 19:48:53 +08:00
16199e968e fix notion import limit check (#1042)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-28 16:49:03 +08:00
02452421d5 fix: pub generate message text return null (#1037) 2023-08-28 16:43:54 +08:00
3a5c7c75ad Fix/model selector (#1032) 2023-08-28 10:54:41 +08:00
a7415ecfd8 Fix/upload document limit (#1033) 2023-08-28 10:53:45 +08:00
934def5fcc Fix: eslint (#1030) 2023-08-27 17:06:16 +08:00
0796791de5 feat: hf inference endpoint stream support (#1028) 2023-08-26 19:48:34 +08:00
6c148b223d fix: dataset query truncated (#1026) 2023-08-26 17:35:17 +08:00
4b168f4838 fix: maintenance notice (#1025) 2023-08-26 16:09:55 +08:00
1c114eaef3 feat: update contributing (#1020) 2023-08-25 21:19:13 +08:00
e053215155 fix document estimate parameter (#1019)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 20:10:08 +08:00
13482b0fc1 feat: maintenance notice (#1016) 2023-08-25 19:38:52 +08:00
38fa152cc4 fix update document index technique (#1018)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 18:29:55 +08:00
2d9616c29c fix: xinference last token being ignored (#1013) 2023-08-25 18:15:05 +08:00
915e26527b update dataset index struct (#1012)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 15:52:33 +08:00
2d604d9330 Fix/filter empty segment (#1004)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 15:50:29 +08:00
e7199826cc embedding model available check (#1009)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-25 00:25:16 +08:00
70e24b7594 fix: loading and calc rem (#1006) 2023-08-24 23:24:33 +08:00
c1602aafc7 refactor:cache in place & function name (#1001) 2023-08-24 22:54:21 +08:00
a3fec11438 fix: styles (#1005) 2023-08-24 22:37:46 +08:00
b1fd1b3ab3 Feat/vector db manage (#997)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-24 21:27:31 +08:00
5397799aac document limit (#999)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-24 21:27:13 +08:00
8e837dde1a feat: bump version to 0.3.18 (#1000) 2023-08-24 18:13:18 +08:00
9ae91a2ec3 feat: optimize xinference request max token key and stop reason (#998) 2023-08-24 18:11:15 +08:00
276d3d10a0 fix: apps loading issue (#994) 2023-08-24 17:57:38 +08:00
f13623184a fix style in app share (#995) 2023-08-24 17:57:25 +08:00
ef61e1487f fix: safetensor arm complie error (#996) 2023-08-24 17:38:10 +08:00
701e2b334f feat: remove unnecessary prompt of baichuan (#993) 2023-08-24 15:30:59 +08:00
6ebd6e7890 feat: bump version to 0.3.17 (#992) 2023-08-24 15:12:47 +08:00
bd3a9b2f8d fix: xinference-chat-stream-response (#991) 2023-08-24 14:39:34 +08:00
18d3877151 feat: optimize xinference stream (#989) 2023-08-24 13:58:34 +08:00
53e83d8697 feat: optimize baichuan prompt (#988) 2023-08-24 12:07:10 +08:00
6377fc75c6 chore: update lintrc config (#986) 2023-08-24 11:46:59 +08:00
2c30d19cbe feat: add baichuan prompt (#985) 2023-08-24 10:22:36 +08:00
9b247fccd4 feat: adjust hf max tokens (#979) 2023-08-23 22:24:50 +08:00
3d38aa7138 feat: bump version to 0.3.16 2023-08-23 20:16:54 +08:00
7d2552b3f2 feat: upgrade xinference to 0.2.1 which support stream response (#977) 2023-08-23 20:15:45 +08:00
117a209ad4 Fix:condition for dataset availability check (#973) 2023-08-23 19:57:27 +08:00
071e7800a0 fix: add hf task field (#976)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-08-23 19:48:31 +08:00
a76fde3d23 feat: optimize hf inference endpoint (#975) 2023-08-23 19:47:50 +08:00
1fc57d7358 normalize embedding (#974)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-23 19:10:11 +08:00
916d8be0ae fix: activation page reload issue after activating (#964) 2023-08-23 13:54:40 +08:00
a38412de7b update doc (#965) 2023-08-23 12:29:52 +08:00
9c9f0ddb93 fix: user activation request 404 issue (#963) 2023-08-23 08:57:25 +08:00
f8fbe96da4 feat: bump version to 0.3.15 (#959) 2023-08-22 18:20:33 +08:00
215a27fd95 Feat/add xinference openllm provider (#958) 2023-08-22 18:19:10 +08:00
5cba2e7087 fix: web reader tool retrieve content empty (#957) 2023-08-22 18:01:16 +08:00
5623839c71 update document segment (#950)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-22 17:59:24 +08:00
78d3aa5fcd fix: embedding init err (#956) 2023-08-22 17:43:59 +08:00
a7c78d2cd2 fix: spark provider field name (#955) 2023-08-22 17:28:18 +08:00
4db35fa375 chore: obsolete info api use new api (#954) 2023-08-22 16:59:57 +08:00
e67a1413b6 chore: create btn to first place (#953) 2023-08-22 16:20:56 +08:00
4f3053a8cc fix: xinference chat completion error (#952) 2023-08-22 15:58:04 +08:00
b3c2bf125f Feat/model providers (#951) 2023-08-22 15:38:12 +08:00
9d5299e9ec fix: segment error tip & save segment disable when loading (#949) 2023-08-22 15:22:16 +08:00
aee15adf1b update document segment (#948)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-22 15:19:09 +08:00
b185a70c21 Fix/speech to text button (#947) 2023-08-22 14:55:20 +08:00
a3aba7a9aa fix: provider model not delete when reset key pair (#946) 2023-08-22 13:48:58 +08:00
866ee5da91 fix: openllm generate cutoff (#945) 2023-08-22 13:43:36 +08:00
e8039a7da8 fix: add flex-wrap to categories container (#944) 2023-08-22 13:39:52 +08:00
5e0540077a chore: perfect type definition (#940) 2023-08-22 10:58:06 +08:00
b346bd9b83 fix: default language improvement in activation page (#942) 2023-08-22 09:28:37 +08:00
062e2e915b fix: login improvement (#941) 2023-08-21 21:26:32 +08:00
e0a48c4972 fix: xinference chat support (#939) 2023-08-21 20:44:29 +08:00
f53242c081 Feat/add document status tooltip (#937) 2023-08-21 18:07:51 +08:00
4b53bb1a32 Feat/token support (#909)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: jyong <jyong@dify.ai>
2023-08-21 13:57:18 +08:00
4c49ecedb5 feat: optimize web reader summary in 3.5 (#933) 2023-08-21 11:58:01 +08:00
4ff1870a4b fix: web reader tool missing nodejs (#932) 2023-08-21 11:26:11 +08:00
6c832ee328 fix: remove openllm pypi package because of this package too large (#931) 2023-08-21 02:12:28 +08:00
25264e7852 feat: add xinference embedding model support (#930) 2023-08-20 19:35:07 +08:00
18dd0d569d fix: xinference max_tokens alisa error (#929) 2023-08-20 19:12:52 +08:00
3ea8d7a019 feat: add openllm support (#928) 2023-08-20 19:04:33 +08:00
da3f10a55e feat: server xinference support (#927) 2023-08-20 17:46:41 +08:00
8c991b5b26 Fix Readme.md typo error. (#926) 2023-08-20 12:02:04 +08:00
22c1aafb9b fix: document paused at format error (#925) 2023-08-20 01:54:12 +08:00
8d6d1c442b feat: optimize generate name length (#924) 2023-08-19 23:34:38 +08:00
95b179fb39 fix: replicate text generation model validate (#923) 2023-08-19 21:40:42 +08:00
3a0a9e2d8f fix: embedding get price definition missing (#922) 2023-08-19 21:31:40 +08:00
0a0d63457d feat: record price unit in messages (#919) 2023-08-19 18:51:40 +08:00
920fb6d0e1 fix: embedding price config (#918) 2023-08-19 16:54:08 +08:00
fd0fc8f4fe Fix/price calc (#862) 2023-08-19 16:41:35 +08:00
1c552ff23a fix: azure embedding model credentials include base_model_name is invalid for openai sdk (#917) 2023-08-19 16:24:18 +08:00
5163dd38e5 fix: run extra model serval ex not return (#916) 2023-08-19 14:35:16 +08:00
2a27dad2fb fix: run model serval ex not return (#915) 2023-08-19 14:16:41 +08:00
930f74c610 feat: remove unuse envs (#912) 2023-08-18 21:34:28 +08:00
3f250c9e12 Update README_CN.md 2023-08-18 20:39:40 +08:00
fa408d264c Update README.md 2023-08-18 20:38:52 +08:00
09ea27f1ee feat: optimize service api authorization header invalid error (#910) 2023-08-18 20:32:44 +08:00
db7156dafd Feature/mutil embedding model (#908)
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-08-18 17:37:31 +08:00
4420281d96 Feat/segment add tag (#907) 2023-08-18 17:18:58 +08:00
d9afebe216 feat: optimize output parse (#906) 2023-08-18 17:00:40 +08:00
1d9cc5ca05 fix: universal chat when default model invalid (#905) 2023-08-18 16:20:42 +08:00
edb06f6aed fix: react router agent direct output (#904) 2023-08-18 14:31:20 +08:00
6ca3bcbcfd fix: sensitive_word_avoidance npe (#902) 2023-08-18 11:43:56 +08:00
71a9d63232 fix entrypoint script line endings (#900) 2023-08-18 10:42:44 +08:00
fb62017e50 Fix/embedding chat (#899)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-08-18 10:39:05 +08:00
9adbeadeec feat: claude paid optimize (#890) 2023-08-17 16:56:20 +08:00
2f7b234cc5 fix: max token not exist in generate summary when calc rest tokens (#891) 2023-08-17 16:33:32 +08:00
4f5f9506ab Feat/pay modal (#889) 2023-08-17 15:49:22 +08:00
0cc0b6e052 fix: error raise status code not exist (#888) 2023-08-17 15:33:35 +08:00
cd78adb0ab feat: support show model display name (#887) 2023-08-17 15:13:35 +08:00
f42e7d1a61 feat: add spark v2 support (#885) 2023-08-17 15:08:57 +08:00
c4d759dfba fix: wenxin error not raise when stream mode (#884) 2023-08-17 13:40:00 +08:00
a58f95fa91 fix: web dockfile (#883) 2023-08-17 13:07:07 +08:00
39574dcf6b feat: optimize prompt of suggested_questions_after_answer (#881) 2023-08-17 10:46:33 +08:00
5b06ded0b1 build: improve dockerfile (#851)
Co-authored-by: MatriQi <matri@aifi.io>
2023-08-17 10:25:11 +08:00
155a4733f6 Feat/customizable file upload config (#818) 2023-08-16 23:14:27 +08:00
b7c29ea1b6 feat: optimize model when app create (#875) 2023-08-16 22:29:18 +08:00
cc2d71c253 feat: optimize override app model config convert (#874) 2023-08-16 20:48:42 +08:00
cd11613952 Update README.md (#865) 2023-08-16 19:26:35 +08:00
e0d6d00a87 Update README_CN.md (#867) 2023-08-16 19:26:11 +08:00
2dfb3e95f6 feat: optimize error record in agent (#869) 2023-08-16 15:55:42 +08:00
f207e180df fix multi thread app context (#868)
Co-authored-by: jyong <jyong@dify.ai>
2023-08-16 15:39:31 +08:00
948d64bbef fix: get_num_tokens_from_messages params error (#866) 2023-08-16 14:58:44 +08:00
01e912e543 fix: promptEng menu in wrong place (#864) 2023-08-16 14:56:17 +08:00
f95f6db0e3 feat: support app rename and make app card ui better (#766)
Co-authored-by: Gillian97 <jinling.sunshine@gmail.com>
2023-08-16 10:31:08 +08:00
216fc5d312 feat: bump version 0.3.14 (#861) 2023-08-15 22:46:15 +08:00
7a8590980e fix: dataset direct output (#860) 2023-08-15 22:27:31 +08:00
e8c14bb732 feat: rename title in site both rename name in app (#857) 2023-08-15 20:42:32 +08:00
bf45f08e78 chore: handle provider name capitalization (#855) 2023-08-15 17:22:40 +08:00
2c77a74c40 fix: frontend permission check (#784) 2023-08-15 13:35:47 +08:00
440cf63317 fix: setting modal margin (#849) 2023-08-15 12:05:27 +08:00
50b11e925b fix: change config string variable limit (#837)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2023-08-15 11:26:58 +08:00
7cc81b4269 fix: var config content can not be saved (#841) 2023-08-15 09:51:43 +08:00
93b0813b73 Update README.md (#839) 2023-08-15 09:43:21 +08:00
649b44aefa Update README_CN.md (#840) 2023-08-15 09:43:11 +08:00
1e95d74ae2 update doc (#838) 2023-08-15 09:25:37 +08:00
700d5f2673 update llms (#835) 2023-08-14 22:41:40 +08:00
805 changed files with 34381 additions and 5853 deletions

@@ -1,11 +1,8 @@
-FROM mcr.microsoft.com/devcontainers/anaconda:0-3
+FROM mcr.microsoft.com/devcontainers/python:3.10
 
-# Copy environment.yml (if found) to a temp location so we update the environment. Also
-# copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists.
-COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/
-RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then umask 0002 && /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; fi \
-    && rm -rf /tmp/conda-tmp
+COPY . .
 
 # [Optional] Uncomment this section to install additional OS packages.
 # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
 #     && apt-get -y install --no-install-recommends <your-package-list-here>

@@ -1,13 +1,12 @@
 // For format details, see https://aka.ms/devcontainer.json. For config options, see the
 // README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
 {
-  "name": "Anaconda (Python 3)",
+  "name": "Python 3.10",
   "build": {
     "context": "..",
     "dockerfile": "Dockerfile"
   },
   "features": {
-    "ghcr.io/dhoeric/features/act:1": {},
     "ghcr.io/devcontainers/features/node:1": {
       "nodeGypDependencies": true,
       "version": "lts"

.github/ISSUE_TEMPLATE/bug_report.yml (new file, +49)
@@ -0,0 +1,49 @@
name: "🕷️ Bug report"
description: Report errors or unexpected behavior
labels:
  - bug
body:
  - type: markdown
    attributes:
      value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
  - type: input
    attributes:
      label: Dify version
      placeholder: 0.3.21
      description: See about section in Dify console
    validations:
      required: true
  - type: dropdown
    attributes:
      label: Cloud or Self Hosted
      description: How / Where was Dify installed from?
      multiple: true
      options:
        - Cloud
        - Self Hosted
        - Other (please specify in "Steps to Reproduce")
    validations:
      required: true
  - type: textarea
    attributes:
      label: Steps to reproduce
      description: We highly suggest including screenshots and a bug report log.
      placeholder: Having detailed steps helps us reproduce the bug.
    validations:
      required: true
  - type: textarea
    attributes:
      label: ✔️ Expected Behavior
      placeholder: What were you expecting?
    validations:
      required: false
  - type: textarea
    attributes:
      label: ❌ Actual Behavior
      placeholder: What happened instead?
    validations:
      required: false
.github/ISSUE_TEMPLATE/config.yml (new file, +8)
@@ -0,0 +1,8 @@
blank_issues_enabled: false
contact_links:
  - name: "\U0001F4DA Dify user documentation"
    url: https://docs.dify.ai/getting-started/readme
    about: Documentation for users of Dify
  - name: "\U0001F4DA Dify dev documentation"
    url: https://docs.dify.ai/getting-started/install-self-hosted
    about: Documentation for people interested in developing and contributing for Dify

@@ -0,0 +1,11 @@
name: "📚 Documentation Issue"
description: Report issues in our documentation
labels:
  - ducumentation
body:
  - type: textarea
    attributes:
      label: Provide a description of requested docs changes
      placeholder: Briefly describe which document needs to be corrected and why.
    validations:
      required: true

@@ -0,0 +1,26 @@
name: "⭐ Feature or enhancement request"
description: Propose something new.
labels:
  - enhancement
body:
  - type: textarea
    attributes:
      label: Description of the new feature / enhancement
      placeholder: What is the expected behavior of the proposed feature?
    validations:
      required: true
  - type: textarea
    attributes:
      label: Scenario when this would be used?
      placeholder: What is the scenario this would be used? Why is this important to your workflow as a dify user?
    validations:
      required: true
  - type: textarea
    attributes:
      label: Supporting information
      placeholder: "Having additional evidence, data, tweets, blog posts, research, ... anything is extremely helpful. This information provides context to the scenario that may otherwise be lost."
    validations:
      required: false
  - type: markdown
    attributes:
      value: Please limit one request per issue.

.github/ISSUE_TEMPLATE/help_wanted.yml (new file, +11)
@@ -0,0 +1,11 @@
name: "🤝 Help Wanted"
description: "Request help from the community"
labels:
  - help-wanted
body:
  - type: textarea
    attributes:
      label: Provide a description of the help you need
      placeholder: Briefly describe what you need help with.
    validations:
      required: true

@@ -0,0 +1,46 @@
name: "🌐 Localization/Translation issue"
description: Report incorrect translations.
labels:
  - translation
body:
  - type: markdown
    attributes:
      value: Please make sure to [search for existing issues](https://github.com/langgenius/dify/issues) before filing a new one!
  - type: input
    attributes:
      label: Dify version
      placeholder: 0.3.21
      description: Hover over system tray icon or look at Settings
    validations:
      required: true
  - type: input
    attributes:
      label: Utility with translation issue
      placeholder: Some area
      description: Please input here the utility with the translation issue
    validations:
      required: true
  - type: input
    attributes:
      label: 🌐 Language affected
      placeholder: "German"
    validations:
      required: true
  - type: textarea
    attributes:
      label: ❌ Actual phrase(s)
      placeholder: What is there? Please include a screenshot as that is extremely helpful.
    validations:
      required: true
  - type: textarea
    attributes:
      label: ✔️ Expected phrase(s)
      placeholder: What was expected?
    validations:
      required: true
  - type: textarea
    attributes:
      label: Why is the current translation wrong
      placeholder: Why do you feel this is incorrect?
    validations:
      required: true

@@ -1,32 +0,0 @@
---
name: "\U0001F41B Bug report"
about: Create a report to help us improve
title: ''
labels: bug
assignees: ''
---
<!--
Please provide a clear and concise description of what the bug is. Include
screenshots if needed. Please test using the latest version of the relevant
Dify packages to make sure your issue has not already been fixed.
-->
Dify version: Cloud | Self Host
## Steps To Reproduce
<!--
Your bug will get fixed much faster if we can run your code and it doesn't
have dependencies other than Dify. Issues without reproduction steps or
code examples may be immediately closed as not actionable.
-->
1.
2.
## The current behavior
## The expected behavior

@@ -1,20 +0,0 @@
---
name: "\U0001F680 Feature request"
about: Suggest an idea for this project
title: ''
labels: enhancement
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

@@ -1,10 +0,0 @@
---
name: "\U0001F914 Questions and Help"
about: Ask a usage or consultation question
title: ''
labels: ''
assignees: ''
---

38
.github/workflows/api-unit-tests.yml vendored Normal file
View File

@@ -0,0 +1,38 @@
name: Run Pytest
on:
  pull_request:
    branches:
      - main
  push:
    branches:
      - deploy/dev
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.10'
      - name: Cache pip dependencies
        uses: actions/cache@v2
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }}
          restore-keys: ${{ runner.os }}-pip-
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest
          pip install -r api/requirements.txt
      - name: Run pytest
        run: pytest api/tests/unit_tests
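The workflow above runs whatever pytest collects under `api/tests/unit_tests`. A minimal sketch of a test file that this step would pick up; the helper function and its behavior are illustrative, not taken from the Dify codebase:

```python
# Hypothetical test module; pytest collects files named test_*.py and
# functions prefixed with test_ automatically.

def snake_to_camel(name: str) -> str:
    """Toy helper standing in for a real unit under api/."""
    head, *rest = name.split('_')
    return head + ''.join(part.title() for part in rest)

def test_snake_to_camel():
    # Plain asserts are enough; pytest rewrites them for rich failure output.
    assert snake_to_camel('app_model_config') == 'appModelConfig'
    assert snake_to_camel('app') == 'app'
```

Dropping such a file into `api/tests/unit_tests/` is all it takes for the CI step to execute it.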


@@ -20,7 +20,8 @@ def check_file_for_chinese_comments(file_path):
 def main():
     has_chinese = False
-    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
-                      'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
+    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
+                      'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
+                      'prompts.py']
     for root, _, files in os.walk("."):
         for file in files:

.gitignore

@@ -144,9 +144,11 @@ docker/volumes/app/storage/*
 docker/volumes/db/data/*
 docker/volumes/redis/data/*
 docker/volumes/weaviate/*
+docker/volumes/qdrant/*
 sdks/python-client/build
 sdks/python-client/dist
 sdks/python-client/dify_client.egg-info
-.vscode/
+.vscode/*
+!.vscode/launch.json

.vscode/launch.json

@@ -0,0 +1,27 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Flask",
            "type": "python",
            "request": "launch",
            "module": "flask",
            "env": {
                "FLASK_APP": "api/app.py",
                "FLASK_DEBUG": "1",
                "GEVENT_SUPPORT": "True"
            },
            "args": [
                "run",
                "--host=0.0.0.0",
                "--port=5001",
                "--debug"
            ],
            "jinja": true,
            "justMyCode": true
        }
    ]
}


@@ -53,9 +53,9 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull request?
 ## Community channels
-Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
+Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help!
 ### i18n (Internationalization) Support
 We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know.
 Also check out the [Frontend i18n README](web/i18n/README_EN.md) for more information.


@@ -16,15 +16,15 @@
 ## Local development
-To set up a working development environment, just fork the project's git repository, install the backend and frontend dependencies with the appropriate package managers, then create and run the docker-compose stack
+To set up a working development environment, just fork the project's git repository, install the backend and frontend dependencies with the appropriate package managers, then create and run docker-compose.
 ### Fork the repository
-You need to fork the [repository](https://github.com/langgenius/dify).
+You need to fork the [Git repository](https://github.com/langgenius/dify).
 ### Clone the repository
 Clone the repository you forked on GitHub:
 ```
 git clone git@github.com:<github_username>/dify.git


@@ -52,4 +52,4 @@ git clone git@github.com:<github_username>/dify.git
 ## Community channels
-Stuck somewhere? Have any questions? Join the [Discord Community server](https://discord.gg/AhzKf7dNgk). We are here to help!
+Stuck somewhere? Have any questions? Join the [Discord Community server](https://discord.gg/j3XRWSPBf7). We are here to help!


@@ -16,18 +16,31 @@ Out-of-the-box web sites supporting form mode and chat conversation mode
 A single API encompassing plugin capabilities, context enhancement, and more, saving you backend coding effort
 Visual data analysis, log review, and annotation for applications
 https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
 ## Highlighted Features
 **1. LLMs support:** Choose capabilities based on different models when building your Dify AI apps. Dify is compatible with Langchain, meaning it will support various LLMs. Currently supported:
->* OpenAI: GPT-4, GPT-3.5-turbo, GPT-3.5-turbo-16k, text-davinci-003
->* Azure OpenAI Service
->* Anthropic: Claude2, Claude-instant
->* Hugging Face Hub (coming soon)
+- [x] **OpenAI**: GPT4, GPT3.5-turbo, GPT3.5-turbo-16k, text-davinci-003
+- [x] **Azure OpenAI Service**
+- [x] **Anthropic**: Claude2, Claude-instant
+- [x] **Replicate**
+- [x] **Hugging Face Hub**
+- [x] **ChatGLM**
+- [x] **Llama2**
+- [x] **MiniMax**
+- [x] **Spark**
+- [x] **Wenxin**
+- [x] **Tongyi**
 We provide the following free resources for registered Dify cloud users (sign up at [dify.ai](https://dify.ai)):
-* 1000 free Claude model queries to build Claude-powered apps
+* 600,000 free Claude model tokens to build Claude-powered apps
 * 200 free OpenAI queries to build OpenAI-based apps
 **2. Visual orchestration:** Build an AI app in minutes by writing and debugging prompts visually.
 **3. Text embedding:** Fully automated text preprocessing embeds your data as context without complex concepts. Supports PDF, TXT, and syncing data from Notion, webpages, APIs.
@@ -55,7 +68,7 @@ Visit [Dify.ai](https://dify.ai)
 Before installing Dify, make sure your machine meets the following minimum system requirements:
-- CPU >= 1 Core
+- CPU >= 2 Core
 - RAM >= 4GB
 ### Quick Start
@@ -86,8 +99,6 @@ Features under development:
 We will support more datasets, including text, webpages, and even Notion content. Users can build AI applications based on their own data sources.
 - **Plugins**, introducing ChatGPT Plugin-standard plugins for applications, or using Dify-produced plugins
 We will release plugins complying with ChatGPT standard, or Dify's own plugins to enable more capabilities in applications.
-- **Open-source models**, e.g. adopting Llama as a model provider or for further fine-tuning
-We will work with excellent open-source models like Llama, by providing them as model options in our platform, or using them for further fine-tuning.
 ## Q&A


@@ -17,19 +17,29 @@
 - A single API covering plugins, context enhancement, and more, saving you the work of writing backend code
 - Visual data analysis, log review, and annotation for applications
 https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
 ## Core capabilities
 1. **Model support:** You can build your AI apps on Dify with capabilities from different models. Dify is compatible with Langchain, which means we will gradually support a variety of LLMs. Currently supported model providers:
-> * **OpenAI**: GPT4, GPT3.5-turbo, GPT3.5-turbo-16k, text-davinci-003
-> * **Azure OpenAI Service**
-> * **Anthropic**: Claude2, Claude-instant
-> * **Hugging Face Hub** (coming soon)
+- [x] **OpenAI**: GPT4, GPT3.5-turbo, GPT3.5-turbo-16k, text-davinci-003
+- [x] **Azure OpenAI Service**
+- [x] **Anthropic**: Claude2, Claude-instant
+- [x] **Replicate**
+- [x] **Hugging Face Hub**
+- [x] **ChatGLM**
+- [x] **Llama2**
+- [x] **MiniMax**
+- [x] **Spark (iFLYTEK)**
+- [x] **Wenxin (ERNIE Bot)**
+- [x] **Tongyi (Qwen)**
 We provide the following free resources for all registered cloud users (log in at [dify.ai](https://cloud.dify.ai) to use them):
-* 1,000 Claude model message calls for building Claude-based AI apps
+* 600,000 Claude model tokens for building Claude-based AI apps
 * 200 OpenAI model message calls for building OpenAI-based AI apps
+* 3,000,000 Spark (iFLYTEK) model tokens for building Spark-based AI apps
+* 1,000,000 MiniMax tokens for building MiniMax-based AI apps
 2. **Visual prompt orchestration:** Write and debug prompts in a UI and publish an AI app within minutes.
 3. **Text embedding (datasets):** Fully automated text preprocessing uses your data as context, with no need to master obscure concepts or techniques. Supports PDF, txt and other formats, and syncs data from Notion, webpages, and APIs.
 4. **API-based development:** Backend-as-a-Service. Access the web app directly or integrate it into your own application via API, without worrying about complex backend architecture or deployment.
@@ -53,7 +63,7 @@
 Before installing Dify, make sure your machine meets the following minimum system requirements:
-- CPU >= 1 Core
+- CPU >= 2 Core
 - RAM >= 4GB
 ### Quick start
@@ -82,8 +92,6 @@ docker compose up -d
 - **Datasets**, supporting more datasets, e.g. syncing content from webpages and APIs. Users can build AI applications based on their own data sources.
 - **Plugins**, releasing ChatGPT-standard plugins and more Dify-native plugins, plus user-defined plugin capabilities to enable more features in applications, such as goal-oriented decomposition and reasoning tasks.
-- **Open-source model support**, supporting open-source models on Hugging Face Hub, e.g. adopting Llama as a model provider or for further fine-tuning
-We will work with excellent open-source models such as Llama, by offering them as model options on our platform or using them for further fine-tuning.
 ## Q&A


@@ -32,7 +32,7 @@ Visit [Dify.ai](https://dify.ai)
 Before installing Dify, make sure your machine meets the following minimum system requirements:
-- CPU >= 1 Core
+- CPU >= 2 Core
 - RAM >= 4GB
 ### Quick start


@@ -50,25 +50,7 @@ S3_REGION=your-region
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
-# Cookie configuration
-COOKIE_HTTPONLY=true
-COOKIE_SAMESITE=None
-COOKIE_SECURE=true
-# Session configuration
-SESSION_PERMANENT=true
-SESSION_USE_SIGNER=true
-## support redis, sqlalchemy
-SESSION_TYPE=redis
-# session redis configuration
-SESSION_REDIS_HOST=localhost
-SESSION_REDIS_PORT=6379
-SESSION_REDIS_PASSWORD=difyai123456
-SESSION_REDIS_DB=2
-# Vector database configuration, support: weaviate, qdrant
+# Vector database configuration, support: weaviate, qdrant, milvus
 VECTOR_STORE=weaviate
 # Weaviate configuration
@@ -77,9 +59,16 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
 WEAVIATE_GRPC_ENABLED=false
 WEAVIATE_BATCH_SIZE=100
-# Qdrant configuration, use `path:` prefix for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
-QDRANT_URL=path:storage/qdrant
-QDRANT_API_KEY=your-qdrant-api-key
+# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
+QDRANT_URL=http://localhost:6333
+QDRANT_API_KEY=difyai123456
+# Milvus configuration
+MILVUS_HOST=127.0.0.1
+MILVUS_PORT=19530
+MILVUS_USER=root
+MILVUS_PASSWORD=Milvus
+MILVUS_SECURE=false
 # Mail configuration, support: resend
 MAIL_TYPE=
@@ -117,10 +106,12 @@ HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
 HOSTED_ANTHROPIC_ENABLED=false
 HOSTED_ANTHROPIC_API_BASE=
 HOSTED_ANTHROPIC_API_KEY=
-HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
+HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
 HOSTED_ANTHROPIC_PAID_ENABLED=false
 HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
-HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
+HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1000000
+HOSTED_ANTHROPIC_PAID_MIN_QUANTITY=20
+HOSTED_ANTHROPIC_PAID_MAX_QUANTITY=100
 STRIPE_API_KEY=
 STRIPE_WEBHOOK_SECRET=
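The new `MILVUS_*` keys above are plain strings in the environment, so the API's config layer has to coerce the port and the `secure` flag. A minimal sketch of how such settings might be collected into a typed dict; the helper and its defaults mirror `.env.example` but are illustrative, not Dify's actual config code:

```python
import os

def load_milvus_config(env=os.environ) -> dict:
    """Collect the MILVUS_* settings from .env.example into one dict.
    Defaults mirror the example file; the helper itself is a sketch."""
    return {
        'host': env.get('MILVUS_HOST', '127.0.0.1'),
        'port': int(env.get('MILVUS_PORT', '19530')),
        'user': env.get('MILVUS_USER', 'root'),
        'password': env.get('MILVUS_PASSWORD', 'Milvus'),
        # environment variables are strings, so the boolean needs explicit parsing
        'secure': env.get('MILVUS_SECURE', 'false').lower() == 'true',
    }
```

Passing a dict instead of reading `os.environ` directly keeps the parsing testable in isolation.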


@@ -1,7 +1,18 @@
-FROM python:3.10-slim
+# packages install stage
+FROM python:3.10-slim AS base
 LABEL maintainer="takatost@gmail.com"
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends gcc g++ python3-dev libc-dev libffi-dev
+COPY requirements.txt /requirements.txt
+RUN pip install --prefix=/pkg -r requirements.txt
+# build stage
+FROM python:3.10-slim AS builder
 ENV FLASK_APP app.py
 ENV EDITION SELF_HOSTED
 ENV DEPLOY_ENV PRODUCTION
@@ -15,15 +26,17 @@ EXPOSE 5001
 WORKDIR /app/api
-RUN apt-get update && \
-    apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev
-COPY requirements.txt /app/api/requirements.txt
-RUN pip install -r requirements.txt
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends bash curl wget vim nodejs \
+    && apt-get autoremove \
+    && rm -rf /var/lib/apt/lists/*
+COPY --from=base /pkg /usr/local
 COPY . /app/api/
+RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')"
+ENV TRANSFORMERS_OFFLINE true
 COPY docker/entrypoint.sh /entrypoint.sh
 RUN chmod +x /entrypoint.sh


@@ -52,11 +52,13 @@
 flask run --host 0.0.0.0 --port=5001 --debug
 ```
 7. Setup your application by visiting http://localhost:5001/console/api/setup or other apis...
-8. If you need to debug local async processing, you can run `celery -A app.celery worker -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
+8. If you need to debug local async processing, you can run `celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`, celery can do dataset importing and other async tasks.
-8. Start frontend:
+8. Start frontend
+You can start the frontend by running `npm install && npm run dev` in web/ folder, or you can use docker to start the frontend, for example:
 ```
-docker run -it -d --platform linux/amd64 -p 3000:3000 -e EDITION=SELF_HOSTED -e CONSOLE_URL=http://127.0.0.1:5000 --name web-self-hosted langgenius/dify-web:latest
+docker run -it -d --platform linux/amd64 -p 3000:3000 -e EDITION=SELF_HOSTED -e CONSOLE_URL=http://127.0.0.1:5001 --name web-self-hosted langgenius/dify-web:latest
 ```
 This will start a dify frontend, now you are all set, happy coding!


@@ -1,8 +1,7 @@
 # -*- coding:utf-8 -*-
 import os
-from datetime import datetime
-from werkzeug.exceptions import Forbidden
+from werkzeug.exceptions import Unauthorized
 if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
     from gevent import monkey
@@ -12,12 +11,11 @@ import logging
 import json
 import threading
-from flask import Flask, request, Response, session
+from flask import Flask, request, Response
-import flask_login
 from flask_cors import CORS
 from core.model_providers.providers import hosted
-from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
+from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
     ext_database, ext_storage, ext_mail, ext_stripe
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
@@ -27,12 +25,10 @@ from models import model, account, dataset, web, task, source, tool
 from events import event_handlers
 # DO NOT REMOVE ABOVE
-import core
 from config import Config, CloudEditionConfig
 from commands import register_commands
-from models.account import TenantAccountJoin, AccountStatus
-from models.model import Account, EndUser, App
-from services.account_service import TenantService
+from services.account_service import AccountService
+from libs.passport import PassportService
 import warnings
 warnings.simplefilter("ignore", ResourceWarning)
@@ -85,77 +81,33 @@ def initialize_extensions(app):
     ext_redis.init_app(app)
     ext_storage.init_app(app)
     ext_celery.init_app(app)
-    ext_session.init_app(app)
     ext_login.init_app(app)
     ext_mail.init_app(app)
     ext_sentry.init_app(app)
     ext_stripe.init_app(app)
-def _create_tenant_for_account(account):
-    tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
-    TenantService.create_tenant_member(tenant, account, role='owner')
-    account.current_tenant = tenant
-    return tenant
 # Flask-Login configuration
-@login_manager.user_loader
-def load_user(user_id):
-    """Load user based on the user_id."""
+@login_manager.request_loader
+def load_user_from_request(request_from_flask_login):
+    """Load user based on the request."""
     if request.blueprint == 'console':
         # Check if the user_id contains a dot, indicating the old format
-        if '.' in user_id:
-            tenant_id, account_id = user_id.split('.')
-        else:
-            account_id = user_id
-        account = db.session.query(Account).filter(Account.id == account_id).first()
-        if account:
-            if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
-                raise Forbidden('Account is banned or closed.')
-            workspace_id = session.get('workspace_id')
-            if workspace_id:
-                tenant_account_join = db.session.query(TenantAccountJoin).filter(
-                    TenantAccountJoin.account_id == account.id,
-                    TenantAccountJoin.tenant_id == workspace_id
-                ).first()
-                if not tenant_account_join:
-                    tenant_account_join = db.session.query(TenantAccountJoin).filter(
-                        TenantAccountJoin.account_id == account.id).first()
-                    if tenant_account_join:
-                        account.current_tenant_id = tenant_account_join.tenant_id
-                    else:
-                        _create_tenant_for_account(account)
-                    session['workspace_id'] = account.current_tenant_id
-                else:
-                    account.current_tenant_id = workspace_id
-            else:
-                tenant_account_join = db.session.query(TenantAccountJoin).filter(
-                    TenantAccountJoin.account_id == account.id).first()
-                if tenant_account_join:
-                    account.current_tenant_id = tenant_account_join.tenant_id
-                else:
-                    _create_tenant_for_account(account)
-                    session['workspace_id'] = account.current_tenant_id
-            account.last_active_at = datetime.utcnow()
-            db.session.commit()
-            # Log in the user with the updated user_id
-            flask_login.login_user(account, remember=True)
-            return account
-        else:
-            return None
+        auth_header = request.headers.get('Authorization', '')
+        if ' ' not in auth_header:
+            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        auth_scheme, auth_token = auth_header.split(None, 1)
+        auth_scheme = auth_scheme.lower()
+        if auth_scheme != 'bearer':
+            raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
+        decoded = PassportService().verify(auth_token)
+        user_id = decoded.get('user_id')
+        return AccountService.load_user(user_id)
     else:
         return None
 @login_manager.unauthorized_handler
 def unauthorized_handler():
     """Handle unauthorized requests."""
@@ -212,6 +164,7 @@ if app.config['TESTING']:
 @app.after_request
 def after_request(response):
     """Add Version headers to the response."""
+    response.set_cookie('remember_token', '', expires=0)
     response.headers.add('X-Version', app.config['CURRENT_VERSION'])
     response.headers.add('X-Env', app.config['DEPLOY_ENV'])
     return response
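The diff above replaces the cookie-session `user_loader` with a `request_loader` that validates an `Authorization: Bearer <token>` header. The header-parsing core of that change can be sketched as a standalone function; it mirrors the validation in `load_user_from_request` but raises `ValueError` where the Flask code raises werkzeug's `Unauthorized`:

```python
def parse_bearer_token(auth_header: str) -> str:
    """Extract the token from an `Authorization: Bearer <token>` header.
    Standalone sketch of the checks in load_user_from_request above."""
    if ' ' not in auth_header:
        raise ValueError("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
    # split on the first run of whitespace only, so the token itself may contain spaces-free JWT dots
    scheme, token = auth_header.split(None, 1)
    if scheme.lower() != 'bearer':
        raise ValueError("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
    return token
```

Lower-casing the scheme makes the check case-insensitive, matching the `auth_scheme.lower()` step in the diff.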


@@ -1,26 +1,36 @@
 import datetime
+import json
 import math
 import random
 import string
+import threading
 import time
+import uuid
 import click
-from flask import current_app
+from tqdm import tqdm
+from flask import current_app, Flask
+from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound
+from core.embedding.cached_embedding import CacheEmbedding
 from core.index.index import IndexBuilder
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
+from core.model_providers.models.entity.model_params import ModelType
 from core.model_providers.providers.hosted import hosted_model_providers
+from core.model_providers.providers.openai_provider import OpenAIProvider
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
-from models.account import InvitationCode, Tenant
+from models.account import InvitationCode, Tenant, TenantAccountJoin
-from models.dataset import Dataset, DatasetQuery, Document
+from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
-from models.model import Account
+from models.model import Account, AppModelConfig, App
 import secrets
 import base64
-from models.provider import Provider, ProviderType, ProviderQuotaType
+from models.provider import Provider, ProviderType, ProviderQuotaType, ProviderModel
 @click.command('reset-password', help='Reset the account password.')
@@ -102,6 +112,7 @@ def reset_encrypt_key_pair():
         tenant.encrypt_public_key = generate_key_pair(tenant.id)
         db.session.query(Provider).filter(Provider.provider_type == 'custom').delete()
+        db.session.query(ProviderModel).delete()
         db.session.commit()
         click.echo(click.style('Congratulations! '
@@ -230,7 +241,13 @@ def clean_unused_dataset_indexes():
                 kw_index = IndexBuilder.get_index(dataset, 'economy')
                 # delete from vector index
                 if vector_index:
-                    vector_index.delete()
+                    if dataset.collection_binding_id:
+                        vector_index.delete_by_group_id(dataset.id)
+                    else:
+                        if dataset.collection_binding_id:
+                            vector_index.delete_by_group_id(dataset.id)
+                        else:
+                            vector_index.delete()
                 kw_index.delete()
                 # update document
                 update_params = {
@@ -258,6 +275,8 @@ def sync_anthropic_hosted_providers():
     click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
     count = 0
+    new_quota_limit = hosted_model_providers.anthropic.quota_limit
     page = 1
     while True:
         try:
@@ -265,6 +284,7 @@
                 Provider.provider_name == 'anthropic',
                 Provider.provider_type == ProviderType.SYSTEM.value,
                 Provider.quota_type == ProviderQuotaType.TRIAL.value,
+                Provider.quota_limit != new_quota_limit
             ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break
@@ -272,9 +292,9 @@
         page += 1
         for provider in providers:
             try:
-                click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
+                click.echo('Syncing tenant anthropic hosted provider: {}, origin: limit {}, used {}'
+                           .format(provider.tenant_id, provider.quota_limit, provider.quota_used))
                 original_quota_limit = provider.quota_limit
-                new_quota_limit = hosted_model_providers.anthropic.quota_limit
                 division = math.ceil(new_quota_limit / 1000)
                 provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
@@ -292,6 +312,412 @@ def sync_anthropic_hosted_providers():
     click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
def create_qdrant_indexes():
    click.echo(click.style('Start create qdrant indexes.', fg='green'))
    create_count = 0
    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
        except NotFound:
            break
        page += 1
        for dataset in datasets:
            if dataset.index_struct_dict:
                if dataset.index_struct_dict['type'] != 'qdrant':
                    try:
                        click.echo('Create dataset qdrant index: {}'.format(dataset.id))
                        try:
                            embedding_model = ModelFactory.get_embedding_model(
                                tenant_id=dataset.tenant_id,
                                model_provider_name=dataset.embedding_model_provider,
                                model_name=dataset.embedding_model
                            )
                        except Exception:
                            try:
                                embedding_model = ModelFactory.get_embedding_model(
                                    tenant_id=dataset.tenant_id
                                )
                                dataset.embedding_model = embedding_model.name
                                dataset.embedding_model_provider = embedding_model.model_provider.provider_name
                            except Exception:
                                provider = Provider(
                                    id='provider_id',
                                    tenant_id=dataset.tenant_id,
                                    provider_name='openai',
                                    provider_type=ProviderType.SYSTEM.value,
                                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
                                    is_valid=True,
                                )
                                model_provider = OpenAIProvider(provider=provider)
                                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                                                  model_provider=model_provider)
                        embeddings = CacheEmbedding(embedding_model)
                        from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
                        index = QdrantVectorIndex(
                            dataset=dataset,
                            config=QdrantConfig(
                                endpoint=current_app.config.get('QDRANT_URL'),
                                api_key=current_app.config.get('QDRANT_API_KEY'),
                                root_path=current_app.root_path
                            ),
                            embeddings=embeddings
                        )
                        if index:
                            index.create_qdrant_dataset(dataset)
                            index_struct = {
                                "type": 'qdrant',
                                "vector_store": {
                                    "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
                            }
                            dataset.index_struct = json.dumps(index_struct)
                            db.session.commit()
                            create_count += 1
                        else:
                            click.echo('passed.')
                    except Exception as e:
                        click.echo(
                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                                        fg='red'))
                        continue
    click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
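The command above pages through datasets until the paginator raises `NotFound`. That drive-the-pages-until-exhausted shape can be sketched without a database; `NotFound` and `paginate` here are stand-ins for the werkzeug/Flask-SQLAlchemy originals, not imports from them:

```python
class NotFound(Exception):
    """Stand-in for werkzeug.exceptions.NotFound raised on an empty page."""

def paginate(items, page, per_page):
    """Return the 1-indexed page of `items`, raising NotFound when exhausted."""
    start = (page - 1) * per_page
    chunk = items[start:start + per_page]
    if not chunk:
        raise NotFound()
    return chunk

def process_all(items, per_page=50):
    """Mirror the while/try/except NotFound loop used by the CLI commands."""
    processed = []
    page = 1
    while True:
        try:
            batch = paginate(items, page, per_page)
        except NotFound:
            break
        page += 1
        processed.extend(batch)
    return processed
```

Fetching in fixed-size pages keeps memory bounded regardless of how many datasets exist.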
@click.command('update-qdrant-indexes', help='Update qdrant indexes.')
def update_qdrant_indexes():
    click.echo(click.style('Start Update qdrant indexes.', fg='green'))
    create_count = 0
    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
        except NotFound:
            break
        page += 1
        for dataset in datasets:
            if dataset.index_struct_dict:
                if dataset.index_struct_dict['type'] != 'qdrant':
                    try:
                        click.echo('Update dataset qdrant index: {}'.format(dataset.id))
                        try:
                            embedding_model = ModelFactory.get_embedding_model(
                                tenant_id=dataset.tenant_id,
                                model_provider_name=dataset.embedding_model_provider,
                                model_name=dataset.embedding_model
                            )
                        except Exception:
                            provider = Provider(
                                id='provider_id',
                                tenant_id=dataset.tenant_id,
                                provider_name='openai',
                                provider_type=ProviderType.CUSTOM.value,
                                encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
                                is_valid=True,
                            )
                            model_provider = OpenAIProvider(provider=provider)
                            embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                                              model_provider=model_provider)
                        embeddings = CacheEmbedding(embedding_model)
                        from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
                        index = QdrantVectorIndex(
                            dataset=dataset,
                            config=QdrantConfig(
                                endpoint=current_app.config.get('QDRANT_URL'),
                                api_key=current_app.config.get('QDRANT_API_KEY'),
                                root_path=current_app.root_path
                            ),
                            embeddings=embeddings
                        )
                        if index:
                            index.update_qdrant_dataset(dataset)
                            create_count += 1
                        else:
                            click.echo('passed.')
                    except Exception as e:
                        click.echo(
                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                                        fg='red'))
                        continue
    click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
    click.echo(click.style('Start normalization collections.', fg='green'))
    normalization_count = []
    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
        except NotFound:
            break
        datasets_result = datasets.items
        page += 1
        for i in range(0, len(datasets_result), 5):
            threads = []
            sub_datasets = datasets_result[i:i + 5]
            for dataset in sub_datasets:
                document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
                    'flask_app': current_app._get_current_object(),
                    'dataset': dataset,
                    'normalization_count': normalization_count
                })
                threads.append(document_format_thread)
                document_format_thread.start()
            for thread in threads:
                thread.join()
    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
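`normalization_collections` fans work out in groups of five threads and joins each group before starting the next. Stripped of the Flask and Qdrant specifics, that batching shape is roughly the following; the `run_in_batches` helper is illustrative, not part of the Dify codebase:

```python
import threading

def run_in_batches(tasks, batch_size=5):
    """Start up to `batch_size` threads at a time and join each group
    before launching the next, as the command above does for datasets."""
    results = []              # shared list, appended to by worker threads
    lock = threading.Lock()   # guard the shared list across threads

    def worker(task):
        value = task()        # each task is a zero-argument callable
        with lock:
            results.append(value)

    for i in range(0, len(tasks), batch_size):
        threads = [threading.Thread(target=worker, args=(t,))
                   for t in tasks[i:i + batch_size]]
        for t in threads:
            t.start()
        for t in threads:
            t.join()          # wait for the whole group before the next batch
    return results
```

Joining per group caps concurrency at `batch_size`, which keeps the number of simultaneous Qdrant connections in the real command bounded.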
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
    with flask_app.app_context():
        try:
            click.echo('restore dataset index: {}'.format(dataset.id))
            try:
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=dataset.tenant_id,
                    model_provider_name=dataset.embedding_model_provider,
                    model_name=dataset.embedding_model
                )
            except Exception:
                provider = Provider(
                    id='provider_id',
                    tenant_id=dataset.tenant_id,
                    provider_name='openai',
                    provider_type=ProviderType.CUSTOM.value,
                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
                    is_valid=True,
                )
                model_provider = OpenAIProvider(provider=provider)
                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                                  model_provider=model_provider)
            embeddings = CacheEmbedding(embedding_model)
            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
                       DatasetCollectionBinding.model_name == embedding_model.name). \
                order_by(DatasetCollectionBinding.created_at). \
                first()
            if not dataset_collection_binding:
                dataset_collection_binding = DatasetCollectionBinding(
                    provider_name=embedding_model.model_provider.provider_name,
                    model_name=embedding_model.name,
                    collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
                )
                db.session.add(dataset_collection_binding)
                db.session.commit()
            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
            index = QdrantVectorIndex(
                dataset=dataset,
                config=QdrantConfig(
                    endpoint=current_app.config.get('QDRANT_URL'),
                    api_key=current_app.config.get('QDRANT_API_KEY'),
                    root_path=current_app.root_path
                ),
                embeddings=embeddings
            )
            if index:
                # index.delete_by_group_id(dataset.id)
                index.restore_dataset_in_one(dataset, dataset_collection_binding)
            else:
                click.echo('passed.')
            normalization_count.append(1)
        except Exception as e:
            click.echo(
                click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                            fg='red'))

@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):
    pre_prompt_template = '{{default_input}}'
    user_input_form_template = {
        "en-US": [
            {
                "paragraph": {
                    "label": "Query",
                    "variable": "default_input",
                    "required": False,
                    "default": ""
                }
            }
        ],
        "zh-Hans": [
            {
                "paragraph": {
                    "label": "查询内容",
                    "variable": "default_input",
                    "required": False,
                    "default": ""
                }
            }
        ]
    }

    click.secho("Start migrating old data so that text-generation apps support the paragraph variable.", fg='green')

    total_records = db.session.query(AppModelConfig) \
        .join(App, App.app_model_config_id == AppModelConfig.id) \
        .filter(App.mode == 'completion') \
        .count()

    if total_records == 0:
        click.secho("No data to migrate.", fg='green')
        return

    num_batches = (total_records + batch_size - 1) // batch_size

    with tqdm(total=total_records, desc="Migrating Data") as pbar:
        for i in range(num_batches):
            offset = i * batch_size
            limit = min(batch_size, total_records - offset)

            click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')

            data_batch = db.session.query(AppModelConfig) \
                .join(App, App.app_model_config_id == AppModelConfig.id) \
                .filter(App.mode == 'completion') \
                .order_by(App.created_at) \
                .offset(offset).limit(limit).all()

            if not data_batch:
                click.secho("No more data to migrate.", fg='green')
                break

            try:
                click.secho(f"Migrating {len(data_batch)} records...", fg='green')
                for data in data_batch:
                    # click.secho(f"Migrating data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')

                    if data.pre_prompt is None:
                        data.pre_prompt = pre_prompt_template
                    else:
                        if pre_prompt_template in data.pre_prompt:
                            continue
                        data.pre_prompt += pre_prompt_template

                    app_data = db.session.query(App) \
                        .filter(App.id == data.app_id) \
                        .one()

                    account_data = db.session.query(Account) \
                        .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
                        .filter(TenantAccountJoin.role == 'owner') \
                        .filter(TenantAccountJoin.tenant_id == app_data.tenant_id) \
                        .one_or_none()

                    if not account_data:
                        continue

                    if data.user_input_form is None or data.user_input_form == 'null':
                        data.user_input_form = json.dumps(user_input_form_template[account_data.interface_language])
                    else:
                        raw_json_data = json.loads(data.user_input_form)
                        raw_json_data.append(user_input_form_template[account_data.interface_language][0])
                        data.user_input_form = json.dumps(raw_json_data)

                    # click.secho(f"Updated data {data.id}, pre_prompt: {data.pre_prompt}, user_input_form: {data.user_input_form}", fg='green')

                db.session.commit()
            except Exception as e:
                click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
                            fg='red')
                continue

            click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
            pbar.update(len(data_batch))
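The batch arithmetic this command relies on is plain ceiling division over offset/limit windows, with the last window clamped to the remaining records. A minimal sketch:

```python
def num_batches(total_records, batch_size):
    # integer ceiling division: how many batches cover total_records
    return (total_records + batch_size - 1) // batch_size

def batch_bounds(total_records, batch_size):
    # yield the (offset, limit) window for each batch, clamping the final one
    for i in range(num_batches(total_records, batch_size)):
        offset = i * batch_size
        yield offset, min(batch_size, total_records - offset)
```

For example, 1001 records at a batch size of 500 yield the windows (0, 500), (500, 500), (1000, 1).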

@click.command('migrate_default_input_to_dataset_query_variable')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def migrate_default_input_to_dataset_query_variable(batch_size):
    click.secho("Starting...", fg='green')

    total_records = db.session.query(AppModelConfig) \
        .join(App, App.app_model_config_id == AppModelConfig.id) \
        .filter(App.mode == 'completion') \
        .filter(AppModelConfig.dataset_query_variable == None) \
        .count()

    if total_records == 0:
        click.secho("No data to migrate.", fg='green')
        return

    num_batches = (total_records + batch_size - 1) // batch_size

    with tqdm(total=total_records, desc="Migrating Data") as pbar:
        for i in range(num_batches):
            offset = i * batch_size
            limit = min(batch_size, total_records - offset)

            click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')

            data_batch = db.session.query(AppModelConfig) \
                .join(App, App.app_model_config_id == AppModelConfig.id) \
                .filter(App.mode == 'completion') \
                .filter(AppModelConfig.dataset_query_variable == None) \
                .order_by(App.created_at) \
                .offset(offset).limit(limit).all()

            if not data_batch:
                click.secho("No more data to migrate.", fg='green')
                break

            try:
                click.secho(f"Migrating {len(data_batch)} records...", fg='green')
                for data in data_batch:
                    config = AppModelConfig.to_dict(data)
                    tools = config["agent_mode"]["tools"]
                    dataset_exists = "dataset" in str(tools)
                    if not dataset_exists:
                        continue

                    user_input_form = config.get("user_input_form", [])
                    for form in user_input_form:
                        paragraph = form.get('paragraph')
                        if paragraph \
                                and paragraph.get('variable') == 'query':
                            data.dataset_query_variable = 'query'
                            break

                        if paragraph \
                                and paragraph.get('variable') == 'default_input':
                            data.dataset_query_variable = 'default_input'
                            break

                db.session.commit()
            except Exception as e:
                click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
                            fg='red')
                continue

            click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
            pbar.update(len(data_batch))
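The inner scan over `user_input_form` picks the first paragraph variable named `query` or `default_input`. Extracted as a standalone helper (hypothetical name, for illustration only), the logic is:

```python
def find_dataset_query_variable(user_input_form):
    # Return the first recognized paragraph variable ('query' or
    # 'default_input'), or None if the form defines neither.
    for form in user_input_form:
        paragraph = form.get('paragraph')
        if paragraph and paragraph.get('variable') in ('query', 'default_input'):
            return paragraph['variable']
    return None
```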

def register_commands(app):
    app.cli.add_command(reset_password)
    app.cli.add_command(reset_email)
@@ -300,3 +726,8 @@ def register_commands(app):
    app.cli.add_command(recreate_all_dataset_indexes)
    app.cli.add_command(sync_anthropic_hosted_providers)
    app.cli.add_command(clean_unused_dataset_indexes)
    app.cli.add_command(create_qdrant_indexes)
    app.cli.add_command(update_qdrant_indexes)
    app.cli.add_command(update_app_model_configs)
    app.cli.add_command(normalization_collections)
    app.cli.add_command(migrate_default_input_to_dataset_query_variable)


@@ -10,9 +10,6 @@ from extensions.ext_redis import redis_client
 dotenv.load_dotenv()

 DEFAULTS = {
-    'COOKIE_HTTPONLY': 'True',
-    'COOKIE_SECURE': 'True',
-    'COOKIE_SAMESITE': 'None',
     'DB_USERNAME': 'postgres',
     'DB_PASSWORD': '',
     'DB_HOST': 'localhost',
@@ -22,10 +19,6 @@ DEFAULTS = {
     'REDIS_PORT': '6379',
     'REDIS_DB': '0',
     'REDIS_USE_SSL': 'False',
-    'SESSION_REDIS_HOST': 'localhost',
-    'SESSION_REDIS_PORT': '6379',
-    'SESSION_REDIS_DB': '2',
-    'SESSION_REDIS_USE_SSL': 'False',
     'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
     'OAUTH_REDIRECT_INDEX_PATH': '/',
     'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
@@ -36,9 +29,6 @@ DEFAULTS = {
     'STORAGE_TYPE': 'local',
     'STORAGE_LOCAL_PATH': 'storage',
     'CHECK_UPDATE_URL': 'https://updates.dify.ai',
-    'SESSION_TYPE': 'sqlalchemy',
-    'SESSION_PERMANENT': 'True',
-    'SESSION_USE_SIGNER': 'True',
     'DEPLOY_ENV': 'PRODUCTION',
     'SQLALCHEMY_POOL_SIZE': 30,
     'SQLALCHEMY_POOL_RECYCLE': 3600,
@@ -48,21 +38,25 @@ DEFAULTS = {
     'WEAVIATE_GRPC_ENABLED': 'True',
     'WEAVIATE_BATCH_SIZE': 100,
     'CELERY_BACKEND': 'database',
-    'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
-    'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
     'HOSTED_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_OPENAI_ENABLED': 'False',
     'HOSTED_OPENAI_PAID_ENABLED': 'False',
     'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
     'HOSTED_AZURE_OPENAI_ENABLED': 'False',
     'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
-    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
+    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,
     'HOSTED_ANTHROPIC_ENABLED': 'False',
     'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
-    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
+    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
+    'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
+    'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
+    'HOSTED_MODERATION_ENABLED': 'False',
+    'HOSTED_MODERATION_PROVIDERS': '',
     'TENANT_DOCUMENT_COUNT': 100,
-    'CLEAN_DAY_SETTING': 30
+    'CLEAN_DAY_SETTING': 30,
+    'UPLOAD_FILE_SIZE_LIMIT': 15,
+    'UPLOAD_FILE_BATCH_LIMIT': 5,
 }
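These defaults only take effect when the corresponding environment variable is unset. A minimal sketch of the lookup pattern this dict assumes (trimmed `DEFAULTS`; the real `get_env`/`get_bool_env` helpers live alongside it in the config module, so their exact behavior here is an assumption):

```python
import os

# trimmed stand-in for the DEFAULTS dict above
DEFAULTS = {'LOG_LEVEL': 'INFO', 'UPLOAD_FILE_SIZE_LIMIT': 15}

def get_env(key):
    # environment variable wins; otherwise fall back to DEFAULTS
    return os.environ.get(key, DEFAULTS.get(key))

def get_bool_env(key):
    # booleans are stored as the strings 'True' / 'False'
    return get_env(key) == 'True'
```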
@@ -98,13 +92,12 @@ class Config:
         self.CONSOLE_URL = get_env('CONSOLE_URL')
         self.API_URL = get_env('API_URL')
         self.APP_URL = get_env('APP_URL')
-        self.CURRENT_VERSION = "0.3.13"
+        self.CURRENT_VERSION = "0.3.26"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
         self.TESTING = False
         self.LOG_LEVEL = get_env('LOG_LEVEL')
-        self.PDF_PREVIEW = get_bool_env('PDF_PREVIEW')

         # Your App secret key will be used for securely signing the session cookie
         # Make sure you are changing this key for your deployment with a strong key.
@@ -112,20 +105,6 @@ class Config:
         # Alternatively you can set it with `SECRET_KEY` environment variable.
         self.SECRET_KEY = get_env('SECRET_KEY')

-        # cookie settings
-        self.REMEMBER_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
-        self.SESSION_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
-        self.REMEMBER_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
-        self.SESSION_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
-        self.REMEMBER_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
-        self.SESSION_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
-        self.PERMANENT_SESSION_LIFETIME = timedelta(days=7)
-
-        # session settings, only support sqlalchemy, redis
-        self.SESSION_TYPE = get_env('SESSION_TYPE')
-        self.SESSION_PERMANENT = get_bool_env('SESSION_PERMANENT')
-        self.SESSION_USE_SIGNER = get_bool_env('SESSION_USE_SIGNER')
-
         # redis settings
         self.REDIS_HOST = get_env('REDIS_HOST')
         self.REDIS_PORT = get_env('REDIS_PORT')
@@ -134,14 +113,6 @@ class Config:
         self.REDIS_DB = get_env('REDIS_DB')
         self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')

-        # session redis settings
-        self.SESSION_REDIS_HOST = get_env('SESSION_REDIS_HOST')
-        self.SESSION_REDIS_PORT = get_env('SESSION_REDIS_PORT')
-        self.SESSION_REDIS_USERNAME = get_env('SESSION_REDIS_USERNAME')
-        self.SESSION_REDIS_PASSWORD = get_env('SESSION_REDIS_PASSWORD')
-        self.SESSION_REDIS_DB = get_env('SESSION_REDIS_DB')
-        self.SESSION_REDIS_USE_SSL = get_bool_env('SESSION_REDIS_USE_SSL')
-
         # storage settings
         self.STORAGE_TYPE = get_env('STORAGE_TYPE')
         self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
@@ -164,6 +135,14 @@ class Config:
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')

+        # milvus settings
+        self.MILVUS_HOST = get_env('MILVUS_HOST')
+        self.MILVUS_PORT = get_env('MILVUS_PORT')
+        self.MILVUS_USER = get_env('MILVUS_USER')
+        self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
+        self.MILVUS_SECURE = get_env('MILVUS_SECURE')
+
         # cors settings
         self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
             'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
@@ -209,7 +188,7 @@ class Config:
         self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
         self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
-        self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
         self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
         self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
         self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
@@ -217,23 +196,24 @@ class Config:
         self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
         self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
         self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
-        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT'))

         self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
         self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
         self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
-        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
+        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = int(get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT'))
         self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
         self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
-        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
+        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = int(get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA'))
+        self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
+        self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
+
+        self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
+        self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')

         self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
         self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

-        # By default it is False
-        # You could disable it for compatibility with certain OpenAPI providers
-        self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')
-
         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
@@ -244,6 +224,10 @@ class Config:
         self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
         self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')

+        # uploading settings
+        self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
+        self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))


 class CloudEditionConfig(Config):


@@ -16,7 +16,7 @@ model_templates = {
     },
     'model_config': {
         'provider': 'openai',
-        'model_id': 'text-davinci-003',
+        'model_id': 'gpt-3.5-turbo-instruct',
         'configs': {
             'prompt_template': '',
             'prompt_variables': [],
@@ -30,7 +30,7 @@ model_templates = {
     },
     'model': json.dumps({
         "provider": "openai",
-        "name": "text-davinci-003",
+        "name": "gpt-3.5-turbo-instruct",
         "completion_params": {
             "max_tokens": 512,
             "temperature": 1,
@@ -38,7 +38,18 @@ model_templates = {
             "presence_penalty": 0,
             "frequency_penalty": 0
         }
-    })
+    }),
+    'user_input_form': json.dumps([
+        {
+            "paragraph": {
+                "label": "Query",
+                "variable": "query",
+                "required": True,
+                "default": ""
+            }
+        }
+    ]),
+    'pre_prompt': '{{query}}'
     }
 },
@@ -93,7 +104,7 @@ demo_model_templates = {
     'mode': 'completion',
     'model_config': AppModelConfig(
         provider='openai',
-        model_id='text-davinci-003',
+        model_id='gpt-3.5-turbo-instruct',
         configs={
             'prompt_template': "Please translate the following text into {{target_language}}:\n",
             'prompt_variables': [
@@ -129,7 +140,7 @@ demo_model_templates = {
     pre_prompt="Please translate the following text into {{target_language}}:\n",
     model=json.dumps({
         "provider": "openai",
-        "name": "text-davinci-003",
+        "name": "gpt-3.5-turbo-instruct",
         "completion_params": {
             "max_tokens": 1000,
             "temperature": 0,
@@ -211,7 +222,7 @@ demo_model_templates = {
     'mode': 'completion',
     'model_config': AppModelConfig(
         provider='openai',
-        model_id='text-davinci-003',
+        model_id='gpt-3.5-turbo-instruct',
         configs={
             'prompt_template': "请将以下文本翻译为{{target_language}}:\n",
             'prompt_variables': [
@@ -247,7 +258,7 @@ demo_model_templates = {
     pre_prompt="请将以下文本翻译为{{target_language}}:\n",
     model=json.dumps({
         "provider": "openai",
-        "name": "text-davinci-003",
+        "name": "gpt-3.5-turbo-instruct",
         "completion_params": {
             "max_tokens": 1000,
             "temperature": 0,


@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 import flask_restful
 from flask_restful import Resource, fields, marshal_with
 from werkzeug.exceptions import Forbidden
@@ -80,6 +81,7 @@ class BaseApiKeyListResource(Resource):
         key = ApiToken.generate_api_key(self.token_prefix, 24)
         api_token = ApiToken()
         setattr(api_token, self.resource_id_field, resource_id)
+        api_token.tenant_id = current_user.current_tenant_id
         api_token.token = key
         api_token.type = self.resource_type
         db.session.add(api_token)


@@ -1,9 +1,11 @@
 # -*- coding:utf-8 -*-
 import json
+import logging
 from datetime import datetime

-from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
+from flask_login import current_user
+from libs.login import login_required
+from flask_restful import Resource, reqparse, marshal_with, abort, inputs
 from werkzeug.exceptions import Forbidden

 from constants.model_template import model_templates, demo_model_templates
@@ -11,42 +13,16 @@ from controllers.console import api
 from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.model_provider_factory import ModelProviderFactory
 from events.app_event import app_was_created, app_was_deleted
-from libs.helper import TimestampField
+from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
+    app_detail_fields_with_site
 from extensions.ext_database import db
 from models.model import App, AppModelConfig, Site
 from services.app_model_config_service import AppModelConfigService

-model_config_fields = {
-    'opening_statement': fields.String,
-    'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
-    'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
-    'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
-    'more_like_this': fields.Raw(attribute='more_like_this_dict'),
-    'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
-    'model': fields.Raw(attribute='model_dict'),
-    'user_input_form': fields.Raw(attribute='user_input_form_list'),
-    'pre_prompt': fields.String,
-    'agent_mode': fields.Raw(attribute='agent_mode_dict'),
-}
-
-app_detail_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'mode': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String,
-    'enable_site': fields.Boolean,
-    'enable_api': fields.Boolean,
-    'api_rpm': fields.Integer,
-    'api_rph': fields.Integer,
-    'is_demo': fields.Boolean,
-    'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
-    'created_at': TimestampField
-}

 def _get_app(app_id, tenant_id):
     app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
@@ -56,35 +32,6 @@ def _get_app(app_id, tenant_id):

 class AppListApi(Resource):
-    prompt_config_fields = {
-        'prompt_template': fields.String,
-    }
-
-    model_config_partial_fields = {
-        'model': fields.Raw(attribute='model_dict'),
-        'pre_prompt': fields.String,
-    }
-
-    app_partial_fields = {
-        'id': fields.String,
-        'name': fields.String,
-        'mode': fields.String,
-        'icon': fields.String,
-        'icon_background': fields.String,
-        'enable_site': fields.Boolean,
-        'enable_api': fields.Boolean,
-        'is_demo': fields.Boolean,
-        'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
-        'created_at': TimestampField
-    }
-
-    app_pagination_fields = {
-        'page': fields.Integer,
-        'limit': fields.Integer(attribute='per_page'),
-        'total': fields.Integer,
-        'has_more': fields.Boolean(attribute='has_next'),
-        'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
-    }
-
     @setup_required
     @login_required
@@ -124,12 +71,40 @@ class AppListApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()

+        try:
+            default_model = ModelFactory.get_text_generation_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except (ProviderTokenNotInitError, LLMBadRequestError):
+            default_model = None
+        except Exception as e:
+            logging.exception(e)
+            default_model = None
+
         if args['model_config'] is not None:
             # validate config
+            model_config_dict = args['model_config']
+
+            # get model provider
+            model_provider = ModelProviderFactory.get_preferred_model_provider(
+                current_user.current_tenant_id,
+                model_config_dict["model"]["provider"]
+            )
+
+            if not model_provider:
+                if not default_model:
+                    raise ProviderNotInitializeError(
+                        f"No Default System Reasoning Model available. Please configure "
+                        f"in the Settings -> Model Provider.")
+                else:
+                    model_config_dict["model"]["provider"] = default_model.model_provider.provider_name
+                    model_config_dict["model"]["name"] = default_model.name
+
             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,
                 account=current_user,
-                config=args['model_config']
+                config=model_config_dict,
+                mode=args['mode']
             )

             app = App(
@@ -141,21 +116,8 @@ class AppListApi(Resource):
                 status='normal'
             )

-            app_model_config = AppModelConfig(
-                provider="",
-                model_id="",
-                configs={},
-                opening_statement=model_configuration['opening_statement'],
-                suggested_questions=json.dumps(model_configuration['suggested_questions']),
-                suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
-                speech_to_text=json.dumps(model_configuration['speech_to_text']),
-                more_like_this=json.dumps(model_configuration['more_like_this']),
-                sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
-                model=json.dumps(model_configuration['model']),
-                user_input_form=json.dumps(model_configuration['user_input_form']),
-                pre_prompt=model_configuration['pre_prompt'],
-                agent_mode=json.dumps(model_configuration['agent_mode']),
-            )
+            app_model_config = AppModelConfig()
+            app_model_config = app_model_config.from_model_config_dict(model_configuration)
         else:
             if 'mode' not in args or args['mode'] is None:
                 abort(400, message="mode is required")
@@ -165,20 +127,22 @@ class AppListApi(Resource):
             app = App(**model_config_template['app'])
             app_model_config = AppModelConfig(**model_config_template['model_config'])

-            default_model = ModelFactory.get_default_model(
-                tenant_id=current_user.current_tenant_id,
-                model_type=ModelType.TEXT_GENERATION
+            # get model provider
+            model_provider = ModelProviderFactory.get_preferred_model_provider(
+                current_user.current_tenant_id,
+                app_model_config.model_dict["provider"]
             )

-            if default_model:
-                model_dict = app_model_config.model_dict
-                model_dict['provider'] = default_model.provider_name
-                model_dict['name'] = default_model.model_name
-                app_model_config.model = json.dumps(model_dict)
-            else:
-                raise ProviderNotInitializeError(
-                    f"No Text Generation Model available. Please configure a valid provider "
-                    f"in the Settings -> Model Provider.")
+            if not model_provider:
+                if not default_model:
+                    raise ProviderNotInitializeError(
+                        f"No Default System Reasoning Model available. Please configure "
+                        f"in the Settings -> Model Provider.")
+                else:
+                    model_dict = app_model_config.model_dict
+                    model_dict['provider'] = default_model.model_provider.provider_name
+                    model_dict['name'] = default_model.name
+                    app_model_config.model = json.dumps(model_dict)

         app.name = args['name']
         app.mode = args['mode']
@@ -214,18 +178,6 @@ class AppListApi(Resource):

 class AppTemplateApi(Resource):
-    template_fields = {
-        'name': fields.String,
-        'icon': fields.String,
-        'icon_background': fields.String,
-        'description': fields.String,
-        'mode': fields.String,
-        'model_config': fields.Nested(model_config_fields),
-    }
-
-    template_list_fields = {
-        'data': fields.List(fields.Nested(template_fields)),
-    }
-
     @setup_required
     @login_required
@@ -244,38 +196,6 @@ class AppTemplateApi(Resource):

 class AppApi(Resource):
-    site_fields = {
-        'access_token': fields.String(attribute='code'),
-        'code': fields.String,
-        'title': fields.String,
-        'icon': fields.String,
-        'icon_background': fields.String,
-        'description': fields.String,
-        'default_language': fields.String,
-        'customize_domain': fields.String,
-        'copyright': fields.String,
-        'privacy_policy': fields.String,
-        'customize_token_strategy': fields.String,
-        'prompt_public': fields.Boolean,
-        'app_base_url': fields.String,
-    }
-
-    app_detail_fields_with_site = {
-        'id': fields.String,
-        'name': fields.String,
-        'mode': fields.String,
-        'icon': fields.String,
-        'icon_background': fields.String,
-        'enable_site': fields.Boolean,
-        'enable_api': fields.Boolean,
-        'api_rpm': fields.Integer,
-        'api_rph': fields.Integer,
-        'is_demo': fields.Boolean,
-        'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
-        'site': fields.Nested(site_fields),
-        'api_base_url': fields.String,
-        'created_at': TimestampField
-    }
-
     @setup_required
     @login_required
@@ -297,7 +217,7 @@ class AppApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()

         app = _get_app(app_id, current_user.current_tenant_id)

         db.session.delete(app)
@@ -397,29 +317,6 @@ class AppApiStatus(Resource):
         return app

-class AppRateLimit(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @marshal_with(app_detail_fields)
-    def post(self, app_id):
-        parser = reqparse.RequestParser()
-        parser.add_argument('api_rpm', type=inputs.natural, required=False, location='json')
-        parser.add_argument('api_rph', type=inputs.natural, required=False, location='json')
-        args = parser.parse_args()
-
-        app_id = str(app_id)
-        app = _get_app(app_id, current_user.current_tenant_id)
-
-        if args.get('api_rpm'):
-            app.api_rpm = args.get('api_rpm')
-        if args.get('api_rph'):
-            app.api_rph = args.get('api_rph')
-        app.updated_at = datetime.utcnow()
-        db.session.commit()
-        return app
-
 class AppCopy(Resource):
     @staticmethod
     def create_app_copy(app):
@ -439,22 +336,9 @@ class AppCopy(Resource):
@staticmethod @staticmethod
def create_app_model_config_copy(app_config, copy_app_id): def create_app_model_config_copy(app_config, copy_app_id):
copy_app_model_config = AppModelConfig( copy_app_model_config = app_config.copy()
app_id=copy_app_id, copy_app_model_config.app_id = copy_app_id
provider=app_config.provider,
model_id=app_config.model_id,
configs=app_config.configs,
opening_statement=app_config.opening_statement,
suggested_questions=app_config.suggested_questions,
suggested_questions_after_answer=app_config.suggested_questions_after_answer,
speech_to_text=app_config.speech_to_text,
more_like_this=app_config.more_like_this,
sensitive_word_avoidance=app_config.sensitive_word_avoidance,
model=app_config.model,
user_input_form=app_config.user_input_form,
pre_prompt=app_config.pre_prompt,
agent_mode=app_config.agent_mode
)
return copy_app_model_config return copy_app_model_config
@setup_required @setup_required
@ -482,16 +366,6 @@ class AppCopy(Resource):
return copy_app, 201 return copy_app, 201
class AppExport(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, app_id):
# todo
pass
api.add_resource(AppListApi, '/apps') api.add_resource(AppListApi, '/apps')
api.add_resource(AppTemplateApi, '/app-templates') api.add_resource(AppTemplateApi, '/app-templates')
api.add_resource(AppApi, '/apps/<uuid:app_id>') api.add_resource(AppApi, '/apps/<uuid:app_id>')
@ -500,4 +374,3 @@ api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon') api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable') api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable') api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
api.add_resource(AppRateLimit, '/apps/<uuid:app_id>/rate-limit')
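The hunk above replaces a field-by-field clone of `AppModelConfig` with a `copy()` helper on the model plus a re-bound `app_id`. A minimal sketch of that pattern — the dataclass and field names here are illustrative stand-ins, not Dify's actual SQLAlchemy model:

```python
from dataclasses import dataclass, replace

@dataclass
class AppModelConfig:
    # illustrative columns; the real model has many more
    app_id: str
    pre_prompt: str = ''
    opening_statement: str = ''

    def copy(self) -> 'AppModelConfig':
        # return a detached duplicate; the caller re-points app_id before inserting
        return replace(self)

source = AppModelConfig(app_id='app-1', pre_prompt='hello')
copy_config = source.copy()
copy_config.app_id = 'app-2'  # bind the clone to the copied app
```

The advantage is that new columns added to the model are copied automatically instead of being silently dropped from the hand-written keyword list.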


@@ -2,8 +2,8 @@
 import logging

 from flask import request
-from flask_login import login_required
-from werkzeug.exceptions import InternalServerError, NotFound
+from libs.login import login_required
+from werkzeug.exceptions import InternalServerError

 import services
 from controllers.console import api


@@ -5,7 +5,7 @@ from typing import Generator, Union

 import flask_login
 from flask import Response, stream_with_context
-from flask_login import login_required
+from libs.login import login_required
 from werkzeug.exceptions import InternalServerError, NotFound

 import services
@@ -39,9 +39,10 @@ class CompletionMessageApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, location='json')
+        parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         args = parser.parse_args()

         streaming = args['response_mode'] != 'blocking'
@@ -115,6 +116,7 @@ class ChatMessageApi(Resource):
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         args = parser.parse_args()

         streaming = args['response_mode'] != 'blocking'
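The `default=''` on `query` and the new `retriever_from` argument change what `parse_args()` returns when a client omits those keys. A toy stand-in for flask-restful's default handling — simplified, since the real `RequestParser` also enforces types, locations, and required checks:

```python
def parse_args(payload: dict, spec: dict) -> dict:
    # tiny stand-in for RequestParser: missing keys fall back to their default
    return {name: payload.get(name, default) for name, default in spec.items()}

spec = {'query': '', 'retriever_from': 'dev', 'response_mode': None}
args = parse_args({'response_mode': 'blocking'}, spec)
```

With the old declaration a missing `query` came back as `None`, which downstream prompt-building code then had to guard against; defaulting to `''` removes that special case.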


@@ -1,8 +1,9 @@
 from datetime import datetime

 import pytz
-from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, fields, marshal_with
+from flask_login import current_user
+from libs.login import login_required
+from flask_restful import Resource, reqparse, marshal_with
 from flask_restful.inputs import int_range
 from sqlalchemy import or_, func
 from sqlalchemy.orm import joinedload
@@ -12,107 +13,14 @@ from controllers.console import api
 from controllers.console.app import _get_app
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from libs.helper import TimestampField, datetime_string, uuid_value
+from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
+    conversation_message_detail_fields, conversation_with_summary_pagination_fields
+from libs.helper import datetime_string
 from extensions.ext_database import db
 from models.model import Message, MessageAnnotation, Conversation

-account_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'email': fields.String
-}
-
-feedback_fields = {
-    'rating': fields.String,
-    'content': fields.String,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account': fields.Nested(account_fields, allow_null=True),
-}
-
-annotation_fields = {
-    'content': fields.String,
-    'account': fields.Nested(account_fields, allow_null=True),
-    'created_at': TimestampField
-}
-
-message_detail_fields = {
-    'id': fields.String,
-    'conversation_id': fields.String,
-    'inputs': fields.Raw,
-    'query': fields.String,
-    'message': fields.Raw,
-    'message_tokens': fields.Integer,
-    'answer': fields.String,
-    'answer_tokens': fields.Integer,
-    'provider_response_latency': fields.Float,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account_id': fields.String,
-    'feedbacks': fields.List(fields.Nested(feedback_fields)),
-    'annotation': fields.Nested(annotation_fields, allow_null=True),
-    'created_at': TimestampField
-}
-
-feedback_stat_fields = {
-    'like': fields.Integer,
-    'dislike': fields.Integer
-}
-
-model_config_fields = {
-    'opening_statement': fields.String,
-    'suggested_questions': fields.Raw,
-    'model': fields.Raw,
-    'user_input_form': fields.Raw,
-    'pre_prompt': fields.String,
-    'agent_mode': fields.Raw,
-}
-
 class CompletionConversationApi(Resource):
-    class MessageTextField(fields.Raw):
-        def format(self, value):
-            return value[0]['text'] if value else ''
-
-    simple_configs_fields = {
-        'prompt_template': fields.String,
-    }
-
-    simple_model_config_fields = {
-        'model': fields.Raw(attribute='model_dict'),
-        'pre_prompt': fields.String,
-    }
-
-    simple_message_detail_fields = {
-        'inputs': fields.Raw,
-        'query': fields.String,
-        'message': MessageTextField,
-        'answer': fields.String,
-    }
-
-    conversation_fields = {
-        'id': fields.String,
-        'status': fields.String,
-        'from_source': fields.String,
-        'from_end_user_id': fields.String,
-        'from_end_user_session_id': fields.String(),
-        'from_account_id': fields.String,
-        'read_at': TimestampField,
-        'created_at': TimestampField,
-        'annotation': fields.Nested(annotation_fields, allow_null=True),
-        'model_config': fields.Nested(simple_model_config_fields),
-        'user_feedback_stats': fields.Nested(feedback_stat_fields),
-        'admin_feedback_stats': fields.Nested(feedback_stat_fields),
-        'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
-    }
-
-    conversation_pagination_fields = {
-        'page': fields.Integer,
-        'limit': fields.Integer(attribute='per_page'),
-        'total': fields.Integer,
-        'has_more': fields.Boolean(attribute='has_next'),
-        'data': fields.List(fields.Nested(conversation_fields), attribute='items')
-    }
-
     @setup_required
     @login_required
@@ -190,21 +98,11 @@ class CompletionConversationApi(Resource):

 class CompletionConversationDetailApi(Resource):
-    conversation_detail_fields = {
-        'id': fields.String,
-        'status': fields.String,
-        'from_source': fields.String,
-        'from_end_user_id': fields.String,
-        'from_account_id': fields.String,
-        'created_at': TimestampField,
-        'model_config': fields.Nested(model_config_fields),
-        'message': fields.Nested(message_detail_fields, attribute='first_message'),
-    }
-
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(conversation_detail_fields)
+    @marshal_with(conversation_message_detail_fields)
     def get(self, app_id, conversation_id):
         app_id = str(app_id)
         conversation_id = str(conversation_id)
@@ -233,44 +131,11 @@ class CompletionConversationDetailApi(Resource):

 class ChatConversationApi(Resource):
-    simple_configs_fields = {
-        'prompt_template': fields.String,
-    }
-
-    simple_model_config_fields = {
-        'model': fields.Raw(attribute='model_dict'),
-        'pre_prompt': fields.String,
-    }
-
-    conversation_fields = {
-        'id': fields.String,
-        'status': fields.String,
-        'from_source': fields.String,
-        'from_end_user_id': fields.String,
-        'from_end_user_session_id': fields.String,
-        'from_account_id': fields.String,
-        'summary': fields.String(attribute='summary_or_query'),
-        'read_at': TimestampField,
-        'created_at': TimestampField,
-        'annotated': fields.Boolean,
-        'model_config': fields.Nested(simple_model_config_fields),
-        'message_count': fields.Integer,
-        'user_feedback_stats': fields.Nested(feedback_stat_fields),
-        'admin_feedback_stats': fields.Nested(feedback_stat_fields)
-    }
-
-    conversation_pagination_fields = {
-        'page': fields.Integer,
-        'limit': fields.Integer(attribute='per_page'),
-        'total': fields.Integer,
-        'has_more': fields.Boolean(attribute='has_next'),
-        'data': fields.List(fields.Nested(conversation_fields), attribute='items')
-    }
-
     @setup_required
     @login_required
     @account_initialization_required
-    @marshal_with(conversation_pagination_fields)
+    @marshal_with(conversation_with_summary_pagination_fields)
     def get(self, app_id):
         app_id = str(app_id)
@@ -355,19 +220,6 @@ class ChatConversationApi(Resource):

 class ChatConversationDetailApi(Resource):
-    conversation_detail_fields = {
-        'id': fields.String,
-        'status': fields.String,
-        'from_source': fields.String,
-        'from_end_user_id': fields.String,
-        'from_account_id': fields.String,
-        'created_at': TimestampField,
-        'annotated': fields.Boolean,
-        'model_config': fields.Nested(model_config_fields),
-        'message_count': fields.Integer,
-        'user_feedback_stats': fields.Nested(feedback_stat_fields),
-        'admin_feedback_stats': fields.Nested(feedback_stat_fields)
-    }
-
     @setup_required
     @login_required
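The deleted dictionaries are not gone; this diff moves them into a shared `fields.conversation_fields` module so that every controller marshals through one definition instead of each file keeping its own copy. A schematic of the declare-once, import-everywhere pattern — the `marshal` helper below is a heavy simplification of flask-restful's:

```python
# declared once at module level, as fields/conversation_fields.py now does
account_fields = {'id': str, 'name': str, 'email': str}

def marshal(obj, spec):
    # simplified flask-restful marshal: project attributes through the field spec
    return {key: cast(getattr(obj, key)) for key, cast in spec.items()}

class Account:
    # stand-in record; a real one is a SQLAlchemy row
    id = 1
    name = 'alice'
    email = 'alice@example.com'

data = marshal(Account, account_fields)
```

Centralizing the specs eliminates the drift this file had already accumulated: the two local `conversation_detail_fields` dicts above had quietly diverged from each other.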


@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource, reqparse

 from controllers.console import api


@@ -3,7 +3,7 @@ import logging
 from typing import Union, Generator

 from flask import Response, stream_with_context
-from flask_login import current_user, login_required
+from flask_login import current_user
 from flask_restful import Resource, reqparse, marshal_with, fields
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import InternalServerError, NotFound
@@ -16,7 +16,9 @@ from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
-from libs.helper import uuid_value, TimestampField
+from libs.login import login_required
+from fields.conversation_fields import message_detail_fields
+from libs.helper import uuid_value
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from extensions.ext_database import db
 from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
@@ -26,44 +28,6 @@ from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError
 from services.message_service import MessageService

-account_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'email': fields.String
-}
-
-feedback_fields = {
-    'rating': fields.String,
-    'content': fields.String,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account': fields.Nested(account_fields, allow_null=True),
-}
-
-annotation_fields = {
-    'content': fields.String,
-    'account': fields.Nested(account_fields, allow_null=True),
-    'created_at': TimestampField
-}
-
-message_detail_fields = {
-    'id': fields.String,
-    'conversation_id': fields.String,
-    'inputs': fields.Raw,
-    'query': fields.String,
-    'message': fields.Raw,
-    'message_tokens': fields.Integer,
-    'answer': fields.String,
-    'answer_tokens': fields.Integer,
-    'provider_response_latency': fields.Float,
-    'from_source': fields.String,
-    'from_end_user_id': fields.String,
-    'from_account_id': fields.String,
-    'feedbacks': fields.List(fields.Nested(feedback_fields)),
-    'annotation': fields.Nested(annotation_fields, allow_null=True),
-    'created_at': TimestampField
-}
-
 class ChatMessageListApi(Resource):
     message_infinite_scroll_pagination_fields = {


@@ -1,14 +1,14 @@
 # -*- coding:utf-8 -*-
-import json
-
 from flask import request
 from flask_restful import Resource
-from flask_login import login_required, current_user
+from flask_login import current_user

 from controllers.console import api
 from controllers.console.app import _get_app
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from libs.login import login_required
 from events.app_event import app_model_config_was_updated
 from extensions.ext_database import db
 from models.model import AppModelConfig
@@ -30,25 +30,14 @@ class ModelConfigResource(Resource):
         model_configuration = AppModelConfigService.validate_configuration(
             tenant_id=current_user.current_tenant_id,
             account=current_user,
-            config=request.json
+            config=request.json,
+            mode=app_model.mode
         )

         new_app_model_config = AppModelConfig(
             app_id=app_model.id,
-            provider="",
-            model_id="",
-            configs={},
-            opening_statement=model_configuration['opening_statement'],
-            suggested_questions=json.dumps(model_configuration['suggested_questions']),
-            suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
-            speech_to_text=json.dumps(model_configuration['speech_to_text']),
-            more_like_this=json.dumps(model_configuration['more_like_this']),
-            sensitive_word_avoidance=json.dumps(model_configuration['sensitive_word_avoidance']),
-            model=json.dumps(model_configuration['model']),
-            user_input_form=json.dumps(model_configuration['user_input_form']),
-            pre_prompt=model_configuration['pre_prompt'],
-            agent_mode=json.dumps(model_configuration['agent_mode']),
         )
+        new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)

         db.session.add(new_app_model_config)
         db.session.flush()
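`from_model_config_dict` folds the removed keyword-by-keyword construction into the model itself, including the `json.dumps` calls for columns that store structured data as JSON text. A hedged sketch of what such a helper might look like — the attribute names follow the removed code, but the real method lives on Dify's `AppModelConfig` and handles more columns:

```python
import json

class AppModelConfig:
    # sketch only: the real model is a SQLAlchemy row with JSON text columns
    def from_model_config_dict(self, config: dict) -> 'AppModelConfig':
        self.pre_prompt = config['pre_prompt']            # scalar, stored as-is
        self.opening_statement = config['opening_statement']
        self.model = json.dumps(config['model'])          # structured, serialized
        self.agent_mode = json.dumps(config['agent_mode'])
        return self

cfg = AppModelConfig().from_model_config_dict({
    'pre_prompt': 'You are helpful.',
    'opening_statement': 'Hi!',
    'model': {'provider': 'openai', 'name': 'gpt-3.5-turbo'},
    'agent_mode': {'enabled': False},
})
```

Keeping the serialization next to the model means the controller no longer needs its own `import json`, which is exactly why the diff drops that import at the top of the file.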


@@ -1,32 +1,18 @@
 # -*- coding:utf-8 -*-
-from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, fields, marshal_with
+from flask_login import current_user
+from libs.login import login_required
+from flask_restful import Resource, reqparse, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden

 from controllers.console import api
 from controllers.console.app import _get_app
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from fields.app_fields import app_site_fields
 from libs.helper import supported_language
 from extensions.ext_database import db
 from models.model import Site

-app_site_fields = {
-    'app_id': fields.String,
-    'access_token': fields.String(attribute='code'),
-    'code': fields.String,
-    'title': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String,
-    'description': fields.String,
-    'default_language': fields.String,
-    'customize_domain': fields.String,
-    'copyright': fields.String,
-    'privacy_policy': fields.String,
-    'customize_token_strategy': fields.String,
-    'prompt_public': fields.Boolean
-}
-
 def parse_app_site_args():
     parser = reqparse.RequestParser()
@@ -80,6 +66,13 @@ class AppSite(Resource):
             if value is not None:
                 setattr(site, attr_name, value)

+                if attr_name == 'title':
+                    app_model.name = value
+                elif attr_name == 'icon':
+                    app_model.icon = value
+                elif attr_name == 'icon_background':
+                    app_model.icon_background = value
+
         db.session.commit()
         return site


@@ -4,7 +4,8 @@ from datetime import datetime

 import pytz
 from flask import jsonify
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource, reqparse

 from controllers.console import api


@@ -16,26 +16,25 @@ from services.account_service import RegisterService

 class ActivateCheckApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='args')
-        parser.add_argument('email', type=email, required=True, nullable=False, location='args')
+        parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
+        parser.add_argument('email', type=email, required=False, nullable=True, location='args')
         parser.add_argument('token', type=str, required=True, nullable=False, location='args')
         args = parser.parse_args()

-        account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
+        workspaceId = args['workspace_id']
+        reg_email = args['email']
+        token = args['token']

-        tenant = db.session.query(Tenant).filter(
-            Tenant.id == args['workspace_id'],
-            Tenant.status == 'normal'
-        ).first()
+        invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)

-        return {'is_valid': account is not None, 'workspace_name': tenant.name}
+        return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}

 class ActivateApi(Resource):
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('workspace_id', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('email', type=email, required=True, nullable=False, location='json')
+        parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
+        parser.add_argument('email', type=email, required=False, nullable=True, location='json')
         parser.add_argument('token', type=str, required=True, nullable=False, location='json')
         parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
         parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
@@ -44,12 +43,13 @@ class ActivateApi(Resource):
         parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
         args = parser.parse_args()

-        account = RegisterService.get_account_if_token_valid(args['workspace_id'], args['email'], args['token'])
-        if account is None:
+        invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
+        if invitation is None:
             raise AlreadyActivateError()

         RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])

+        account = invitation['account']
         account.name = args['name']

         # generate password salt
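`get_invitation_if_token_valid` now returns a bundle containing both the pending account and its tenant, which is why the handler can drop its separate `Tenant` query and the old one-value `get_account_if_token_valid`. A minimal stub showing the shape the handlers rely on — the token table below is fake, purely for illustration:

```python
class Tenant:
    # fake tenant record; only .name matters for the response
    def __init__(self, name):
        self.name = name

def get_invitation_if_token_valid(workspace_id, email, token):
    # stub: a real implementation validates the invite token against the database
    invites = {('ws-1', 'alice@example.com', 'tok-123'):
               {'account': 'pending-account', 'tenant': Tenant('Acme')}}
    return invites.get((workspace_id, email, token))

invitation = get_invitation_if_token_valid('ws-1', 'alice@example.com', 'tok-123')
response = {'is_valid': invitation is not None,
            'workspace_name': invitation['tenant'].name if invitation else None}
```

Bundling the lookup also fixes a latent crash in the old code: `tenant.name` raised `AttributeError` whenever the workspace query came back empty, whereas the new response degrades to `None`.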


@@ -1,13 +1,13 @@
 import logging
-from datetime import datetime
-from typing import Optional

-import flask_login
 import requests
-from flask import request, redirect, current_app, session
-from flask_login import current_user, login_required
+from flask import request, redirect, current_app
+from flask_login import current_user
 from flask_restful import Resource
 from werkzeug.exceptions import Forbidden

+from libs.login import login_required
 from libs.oauth_data_source import NotionOAuth
 from controllers.console import api
 from ..setup import setup_required
@@ -42,15 +42,34 @@ class OAuthDataSource(Resource):
         if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
             internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
             oauth_provider.save_internal_access_token(internal_secret)
-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
+            return { 'data': '' }
         else:
             auth_url = oauth_provider.get_authorization_url()
-            return redirect(auth_url)
+            return { 'data': auth_url }, 200

 class OAuthDataSourceCallback(Resource):
+    def get(self, provider: str):
+        OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
+        with current_app.app_context():
+            oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
+            if not oauth_provider:
+                return {'error': 'Invalid provider'}, 400
+        if 'code' in request.args:
+            code = request.args.get('code')
+            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&code={code}')
+        elif 'error' in request.args:
+            error = request.args.get('error')
+            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error={error}')
+        else:
+            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error=Access denied')

+class OAuthDataSourceBinding(Resource):
     def get(self, provider: str):
         OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
         with current_app.app_context():
@@ -66,12 +85,7 @@ class OAuthDataSourceCallback(Resource):
                     f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
                 return {'error': 'OAuth data source process failed'}, 400

-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
-        elif 'error' in request.args:
-            error = request.args.get('error')
-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source={error}')
-        else:
-            return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=access_denied')
+            return {'result': 'success'}, 200

 class OAuthDataSourceSync(Resource):
@@ -98,4 +112,5 @@ class OAuthDataSourceSync(Resource):
 api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
 api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
+api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
 api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
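Returning `{'data': auth_url}` instead of a server-side `redirect()` moves the navigation decision to the web client: the console fetches the URL over XHR and redirects itself. A small sketch of that contract — the provider registry and URL below are made up:

```python
def begin_oauth(provider: str, providers: dict, integration_type: str = 'public'):
    # mirrors the reworked endpoint: hand the auth URL back as JSON and let
    # the web app perform the redirect itself
    oauth_provider = providers.get(provider)
    if not oauth_provider:
        return {'error': 'Invalid provider'}, 400
    if integration_type == 'internal':
        # internal integrations need no authorization round trip
        return {'data': ''}, 200
    return {'data': oauth_provider['auth_url']}, 200

providers = {'notion': {'auth_url': 'https://example.com/oauth/authorize'}}
body, status = begin_oauth('notion', providers)
```

The split into `OAuthDataSourceCallback` (browser-facing, still redirects with `?type=notion&code=...`) and `OAuthDataSourceBinding` (JSON, returns `{'result': 'success'}`) follows the same principle: redirects only where a browser is driving, JSON everywhere the frontend calls programmatically.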


@@ -6,7 +6,6 @@ from flask_restful import Resource, reqparse

 import services
 from controllers.console import api
-from controllers.console.error import AccountNotLinkTenantError
 from controllers.console.setup import setup_required
 from libs.helper import email
 from libs.password import valid_password
@@ -37,12 +36,12 @@ class LoginApi(Resource):
         except Exception:
             pass

-        flask_login.login_user(account, remember=args['remember_me'])
         AccountService.update_last_login(account, request)

         # todo: return the user info
+        token = AccountService.get_account_jwt_token(account)

-        return {'result': 'success'}
+        return {'result': 'success', 'data': token}

 class LogoutApi(Resource):
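`get_account_jwt_token` replaces the Flask-Login session cookie with a stateless signed token that the console stores and sends on later requests. Its internals are not shown in this diff; the sketch below uses a bare HMAC signature rather than a real JWT library, purely to illustrate the issue/verify round trip a stateless login depends on:

```python
import base64
import hashlib
import hmac
import json

SECRET = b'dev-secret'  # illustrative; a real app loads its secret from config

def issue_token(account_id: str) -> str:
    # HMAC-signed payload standing in for a real JWT
    payload = base64.urlsafe_b64encode(json.dumps({'user_id': account_id}).encode())
    signature = hmac.new(SECRET, payload, hashlib.sha256).hexdigest()
    return payload.decode() + '.' + signature

def verify_token(token: str):
    payload, signature = token.rsplit('.', 1)
    expected = hmac.new(SECRET, payload.encode(), hashlib.sha256).hexdigest()
    if not hmac.compare_digest(signature, expected):
        return None  # tampered or foreign token
    return json.loads(base64.urlsafe_b64decode(payload))

token = issue_token('account-42')
```

Because validity is proven by the signature rather than server-side session state, the same scheme also lets the OAuth callback below hand the token back via a `?console_token=` query parameter instead of calling `flask_login.login_user`.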


@@ -2,9 +2,8 @@ import logging
 from datetime import datetime
 from typing import Optional

-import flask_login
 import requests
-from flask import request, redirect, current_app, session
+from flask import request, redirect, current_app
 from flask_restful import Resource

 from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth
@@ -75,12 +74,11 @@ class OAuthCallback(Resource):
            account.initialized_at = datetime.utcnow()
            db.session.commit()

-        # login user
-        session.clear()
-        flask_login.login_user(account, remember=True)
         AccountService.update_last_login(account, request)

-        return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
+        token = AccountService.get_account_jwt_token(account)
+        return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')

 def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:


@@ -2,9 +2,10 @@ import datetime
import json
from cachetools import TTLCache
- from flask import request, current_app
- from flask_login import login_required, current_user
- from flask_restful import Resource, marshal_with, fields, reqparse, marshal
+ from flask import request
+ from flask_login import current_user
+ from libs.login import login_required
+ from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api
@@ -13,7 +14,7 @@ from controllers.console.wraps import account_initialization_required
from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from extensions.ext_database import db
- from libs.helper import TimestampField
+ from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
from models.dataset import Document
from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService
@@ -21,43 +22,8 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
cache = TTLCache(maxsize=None, ttl=30)
- FILE_SIZE_LIMIT = 15 * 1024 * 1024  # 15MB
- ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm']
- PREVIEW_WORDS_LIMIT = 3000
class DataSourceApi(Resource):
-     integrate_icon_fields = {
-         'type': fields.String,
-         'url': fields.String,
-         'emoji': fields.String
-     }
-     integrate_page_fields = {
-         'page_name': fields.String,
-         'page_id': fields.String,
-         'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
-         'parent_id': fields.String,
-         'type': fields.String
-     }
-     integrate_workspace_fields = {
-         'workspace_name': fields.String,
-         'workspace_id': fields.String,
-         'workspace_icon': fields.String,
-         'pages': fields.List(fields.Nested(integrate_page_fields)),
-         'total': fields.Integer
-     }
-     integrate_fields = {
-         'id': fields.String,
-         'provider': fields.String,
-         'created_at': TimestampField,
-         'is_bound': fields.Boolean,
-         'disabled': fields.Boolean,
-         'link': fields.String,
-         'source_info': fields.Nested(integrate_workspace_fields)
-     }
-     integrate_list_fields = {
-         'data': fields.List(fields.Nested(integrate_fields)),
-     }
    @setup_required
    @login_required
@@ -134,28 +100,6 @@ class DataSourceApi(Resource):
class DataSourceNotionListApi(Resource):
-     integrate_icon_fields = {
-         'type': fields.String,
-         'url': fields.String,
-         'emoji': fields.String
-     }
-     integrate_page_fields = {
-         'page_name': fields.String,
-         'page_id': fields.String,
-         'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
-         'is_bound': fields.Boolean,
-         'parent_id': fields.String,
-         'type': fields.String
-     }
-     integrate_workspace_fields = {
-         'workspace_name': fields.String,
-         'workspace_id': fields.String,
-         'workspace_icon': fields.String,
-         'pages': fields.List(fields.Nested(integrate_page_fields))
-     }
-     integrate_notion_info_list_fields = {
-         'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
-     }
    @setup_required
    @login_required

View File

@@ -1,7 +1,11 @@
# -*- coding:utf-8 -*-
- from flask import request
- from flask_login import login_required, current_user
- from flask_restful import Resource, reqparse, fields, marshal, marshal_with
+ import flask_restful
+ from flask import request, current_app
+ from flask_login import current_user
+ from controllers.console.apikey import api_key_list, api_key_fields
+ from libs.login import login_required
+ from flask_restful import Resource, reqparse, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api
@@ -10,40 +14,16 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
- from core.model_providers.error import LLMBadRequestError
- from core.model_providers.model_factory import ModelFactory
- from libs.helper import TimestampField
+ from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
+ from core.model_providers.models.entity.model_params import ModelType
+ from fields.app_fields import related_app_list
+ from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
+ from fields.document_fields import document_status_fields
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
- from models.model import UploadFile
+ from models.model import UploadFile, ApiToken
from services.dataset_service import DatasetService, DocumentService
+ from services.provider_service import ProviderService
- dataset_detail_fields = {
-     'id': fields.String,
-     'name': fields.String,
-     'description': fields.String,
-     'provider': fields.String,
-     'permission': fields.String,
-     'data_source_type': fields.String,
-     'indexing_technique': fields.String,
-     'app_count': fields.Integer,
-     'document_count': fields.Integer,
-     'word_count': fields.Integer,
-     'created_by': fields.String,
-     'created_at': TimestampField,
-     'updated_by': fields.String,
-     'updated_at': TimestampField,
- }
- dataset_query_detail_fields = {
-     "id": fields.String,
-     "content": fields.String,
-     "source": fields.String,
-     "source_app_id": fields.String,
-     "created_by_role": fields.String,
-     "created_by": fields.String,
-     "created_at": TimestampField
- }

def _validate_name(name):
@@ -74,8 +54,29 @@ class DatasetListApi(Resource):
        datasets, total = DatasetService.get_datasets(page, limit, provider,
                                                      current_user.current_tenant_id, current_user)
+         # check embedding setting
+         provider_service = ProviderService()
+         valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
+                                                                  ModelType.EMBEDDINGS.value)
+         # if len(valid_model_list) == 0:
+         #     raise ProviderNotInitializeError(
+         #         f"No Embedding Model available. Please configure a valid provider "
+         #         f"in the Settings -> Model Provider.")
+         model_names = []
+         for valid_model in valid_model_list:
+             model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+         data = marshal(datasets, dataset_detail_fields)
+         for item in data:
+             if item['indexing_technique'] == 'high_quality':
+                 item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
+                 if item_model in model_names:
+                     item['embedding_available'] = True
+                 else:
+                     item['embedding_available'] = False
+             else:
+                 item['embedding_available'] = True
        response = {
-             'data': marshal(datasets, dataset_detail_fields),
+             'data': data,
            'has_more': len(datasets) == limit,
            'limit': limit,
            'total': total,
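The per-dataset availability loop above reduces to a small pure function; a sketch for illustration (the helper name and standalone form are hypothetical, the field names mirror the diff):

```python
def embedding_available(item: dict, model_names: set) -> bool:
    """Decide whether a dataset's embedding model is still configured."""
    if item.get('indexing_technique') != 'high_quality':
        # economy-mode datasets do not depend on an embedding provider
        return True
    # high_quality datasets need their exact "model:provider" pair configured
    return f"{item['embedding_model']}:{item['embedding_model_provider']}" in model_names

names = {"text-embedding-ada-002:openai"}
print(embedding_available({'indexing_technique': 'economy'}, names))       # -> True
print(embedding_available({'indexing_technique': 'high_quality',
                           'embedding_model': 'text-embedding-ada-002',
                           'embedding_model_provider': 'openai'}, names))  # -> True
print(embedding_available({'indexing_technique': 'high_quality',
                           'embedding_model': 'other',
                           'embedding_model_provider': 'azure'}, names))   # -> False
```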
@@ -100,15 +101,6 @@ class DatasetListApi(Resource):
        if current_user.current_tenant.current_role not in ['admin', 'owner']:
            raise Forbidden()
-         try:
-             ModelFactory.get_embedding_model(
-                 tenant_id=current_user.current_tenant_id
-             )
-         except LLMBadRequestError:
-             raise ProviderNotInitializeError(
-                 f"No Embedding Model available. Please configure a valid provider "
-                 f"in the Settings -> Model Provider.")
        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
@@ -131,20 +123,40 @@ class DatasetApi(Resource):
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(
                dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
-         return marshal(dataset, dataset_detail_fields), 200
+         data = marshal(dataset, dataset_detail_fields)
+         # check embedding setting
+         provider_service = ProviderService()
+         # get valid model list
+         valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
+                                                                  ModelType.EMBEDDINGS.value)
+         model_names = []
+         for valid_model in valid_model_list:
+             model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+         if data['indexing_technique'] == 'high_quality':
+             item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
+             if item_model in model_names:
+                 data['embedding_available'] = True
+             else:
+                 data['embedding_available'] = False
+         else:
+             data['embedding_available'] = True
+         return data, 200
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id):
        dataset_id_str = str(dataset_id)
+         dataset = DatasetService.get_dataset(dataset_id_str)
+         if dataset is None:
+             raise NotFound("Dataset not found.")
+         # check user's model setting
+         DatasetService.check_dataset_model_setting(dataset)
        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False,
@@ -232,7 +244,11 @@ class DatasetIndexingEstimateApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
+         parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+         parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
+         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                             location='json')
        args = parser.parse_args()
        # validate args
        DocumentService.estimate_args_validate(args)
@@ -250,11 +266,15 @@ class DatasetIndexingEstimateApi(Resource):
            try:
                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                   args['process_rule'], args['doc_form'])
+                                                                   args['process_rule'], args['doc_form'],
+                                                                   args['doc_language'], args['dataset_id'],
+                                                                   args['indexing_technique'])
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    f"No Embedding Model available. Please configure a valid provider "
                    f"in the Settings -> Model Provider.")
+             except ProviderTokenNotInitError as ex:
+                 raise ProviderNotInitializeError(ex.description)
        elif args['info_list']['data_source_type'] == 'notion_import':
            indexing_runner = IndexingRunner()
@@ -262,29 +282,21 @@ class DatasetIndexingEstimateApi(Resource):
            try:
                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                    args['info_list']['notion_info_list'],
-                                                                     args['process_rule'], args['doc_form'])
+                                                                     args['process_rule'], args['doc_form'],
+                                                                     args['doc_language'], args['dataset_id'],
+                                                                     args['indexing_technique'])
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    f"No Embedding Model available. Please configure a valid provider "
                    f"in the Settings -> Model Provider.")
+             except ProviderTokenNotInitError as ex:
+                 raise ProviderNotInitializeError(ex.description)
        else:
            raise ValueError('Data source type not support')
        return response, 200
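The estimate endpoint now accepts `indexing_technique`, `dataset_id`, and `doc_language` alongside the existing arguments. A sketch of a request body matching the parser above (top-level keys come from the diff; the nested shape of `info_list` and all placeholder values are hypothetical):

```python
# Hypothetical body for POST /datasets/indexing-estimate,
# mirroring the reqparse arguments added in this diff.
estimate_payload = {
    'info_list': {
        'data_source_type': 'upload_file',  # or 'notion_import' (key shown in the diff)
        'file_info_list': {'file_ids': ['<file-uuid>']},  # nested shape assumed, not shown
    },
    'process_rule': {'mode': 'automatic'},
    'indexing_technique': 'high_quality',  # new required argument
    'doc_form': 'text_model',              # parser default
    'doc_language': 'English',             # new, defaults to English
    'dataset_id': '<dataset-uuid>',        # new optional argument
}
print(sorted(estimate_payload))  # -> ['dataset_id', 'doc_form', 'doc_language', 'indexing_technique', 'info_list', 'process_rule']
```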
class DatasetRelatedAppListApi(Resource):
-     app_detail_kernel_fields = {
-         'id': fields.String,
-         'name': fields.String,
-         'mode': fields.String,
-         'icon': fields.String,
-         'icon_background': fields.String,
-     }
-     related_app_list = {
-         'data': fields.List(fields.Nested(app_detail_kernel_fields)),
-         'total': fields.Integer,
-     }
    @setup_required
    @login_required
@@ -316,24 +328,6 @@ class DatasetRelatedAppListApi(Resource):
class DatasetIndexingStatusApi(Resource):
-     document_status_fields = {
-         'id': fields.String,
-         'indexing_status': fields.String,
-         'processing_started_at': TimestampField,
-         'parsing_completed_at': TimestampField,
-         'cleaning_completed_at': TimestampField,
-         'splitting_completed_at': TimestampField,
-         'completed_at': TimestampField,
-         'paused_at': TimestampField,
-         'error': fields.String,
-         'stopped_at': TimestampField,
-         'completed_segments': fields.Integer,
-         'total_segments': fields.Integer,
-     }
-     document_status_fields_list = {
-         'data': fields.List(fields.Nested(document_status_fields))
-     }
    @setup_required
    @login_required
@@ -353,16 +347,101 @@ class DatasetIndexingStatusApi(Resource):
                DocumentSegment.status != 're_segment').count()
            document.completed_segments = completed_segments
            document.total_segments = total_segments
-             documents_status.append(marshal(document, self.document_status_fields))
+             documents_status.append(marshal(document, document_status_fields))
        data = {
            'data': documents_status
        }
        return data
class DatasetApiKeyApi(Resource):
max_keys = 10
token_prefix = 'dataset-'
resource_type = 'dataset'
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
keys = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
all()
return {"items": keys}
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_fields)
def post(self):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
current_key_count = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
count()
if current_key_count >= self.max_keys:
flask_restful.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code='max_keys_exceeded'
)
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 200
class DatasetApiDeleteApi(Resource):
resource_type = 'dataset'
@setup_required
@login_required
@account_initialization_required
def delete(self, api_key_id):
api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
key = db.session.query(ApiToken). \
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id). \
first()
if key is None:
flask_restful.abort(404, message='API key not found')
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit()
return {'result': 'success'}, 204
class DatasetApiBaseUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
return {
'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
else request.host_url.rstrip('/')) + '/v1'
}
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
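`DatasetApiKeyApi` caps each tenant at `max_keys = 10` and prefixes tokens with `dataset-`; `ApiToken.generate_api_key(self.token_prefix, 24)` itself is not shown in this diff. A plausible stand-in for both behaviors, using `secrets` (the real implementation may differ):

```python
import secrets
import string

MAX_KEYS = 10            # mirrors DatasetApiKeyApi.max_keys
TOKEN_PREFIX = 'dataset-'

def generate_api_key(prefix: str, n: int) -> str:
    """Prefix plus n random alphanumeric characters (stand-in for ApiToken.generate_api_key)."""
    alphabet = string.ascii_letters + string.digits
    return prefix + ''.join(secrets.choice(alphabet) for _ in range(n))

def can_create_key(current_count: int, max_keys: int = MAX_KEYS) -> bool:
    # Mirrors the 400 'max_keys_exceeded' guard in DatasetApiKeyApi.post
    return current_count < max_keys

key = generate_api_key(TOKEN_PREFIX, 24)
print(key.startswith('dataset-'), len(key) == len('dataset-') + 24)  # -> True True
```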

View File

@@ -1,10 +1,10 @@
# -*- coding:utf-8 -*-
- import random
from datetime import datetime
from typing import List
- from flask import request
- from flask_login import login_required, current_user
+ from flask import request, current_app
+ from flask_login import current_user
+ from libs.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc
from werkzeug.exceptions import NotFound, Forbidden
@@ -22,7 +22,8 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, \
    LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
- from libs.helper import TimestampField
+ from fields.document_fields import document_with_segments_fields, document_fields, \
+     dataset_and_document_fields, document_status_fields
from extensions.ext_database import db
from models.dataset import DatasetProcessRule, Dataset
from models.dataset import Document, DocumentSegment
@@ -31,64 +32,6 @@ from services.dataset_service import DocumentService, DatasetService
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
- dataset_fields = {
-     'id': fields.String,
-     'name': fields.String,
-     'description': fields.String,
-     'permission': fields.String,
-     'data_source_type': fields.String,
-     'indexing_technique': fields.String,
-     'created_by': fields.String,
-     'created_at': TimestampField,
- }
- document_fields = {
-     'id': fields.String,
-     'position': fields.Integer,
-     'data_source_type': fields.String,
-     'data_source_info': fields.Raw(attribute='data_source_info_dict'),
-     'dataset_process_rule_id': fields.String,
-     'name': fields.String,
-     'created_from': fields.String,
-     'created_by': fields.String,
-     'created_at': TimestampField,
-     'tokens': fields.Integer,
-     'indexing_status': fields.String,
-     'error': fields.String,
-     'enabled': fields.Boolean,
-     'disabled_at': TimestampField,
-     'disabled_by': fields.String,
-     'archived': fields.Boolean,
-     'display_status': fields.String,
-     'word_count': fields.Integer,
-     'hit_count': fields.Integer,
-     'doc_form': fields.String,
- }
- document_with_segments_fields = {
-     'id': fields.String,
-     'position': fields.Integer,
-     'data_source_type': fields.String,
-     'data_source_info': fields.Raw(attribute='data_source_info_dict'),
-     'dataset_process_rule_id': fields.String,
-     'name': fields.String,
-     'created_from': fields.String,
-     'created_by': fields.String,
-     'created_at': TimestampField,
-     'tokens': fields.Integer,
-     'indexing_status': fields.String,
-     'error': fields.String,
-     'enabled': fields.Boolean,
-     'disabled_at': TimestampField,
-     'disabled_by': fields.String,
-     'archived': fields.Boolean,
-     'display_status': fields.String,
-     'word_count': fields.Integer,
-     'hit_count': fields.Integer,
-     'completed_segments': fields.Integer,
-     'total_segments': fields.Integer
- }

class DocumentResource(Resource):
    def get_document(self, dataset_id: str, document_id: str) -> Document:
@@ -137,6 +80,10 @@ class GetProcessRuleApi(Resource):
        req_data = request.args
        document_id = req_data.get('document_id')
+         # get default rules
+         mode = DocumentService.DEFAULT_RULES['mode']
+         rules = DocumentService.DEFAULT_RULES['rules']
        if document_id:
            # get the latest process rule
            document = Document.query.get_or_404(document_id)
@@ -157,11 +104,9 @@ class GetProcessRuleApi(Resource):
                order_by(DatasetProcessRule.created_at.desc()). \
                limit(1). \
                one_or_none()
-             mode = dataset_process_rule.mode
-             rules = dataset_process_rule.rules_dict
-         else:
-             mode = DocumentService.DEFAULT_RULES['mode']
-             rules = DocumentService.DEFAULT_RULES['rules']
+             if dataset_process_rule:
+                 mode = dataset_process_rule.mode
+                 rules = dataset_process_rule.rules_dict
        return {
            'mode': mode,
@@ -274,6 +219,8 @@ class DatasetDocumentListApi(Resource):
        parser.add_argument('duplicate', type=bool, nullable=False, location='json')
        parser.add_argument('original_document_id', type=str, required=False, location='json')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                             location='json')
        args = parser.parse_args()

        if not dataset.indexing_technique and not args['indexing_technique']:
@@ -282,15 +229,6 @@ class DatasetDocumentListApi(Resource):
        # validate args
        DocumentService.document_create_args_validate(args)
-         try:
-             ModelFactory.get_embedding_model(
-                 tenant_id=current_user.current_tenant_id
-             )
-         except LLMBadRequestError:
-             raise ProviderNotInitializeError(
-                 f"No Embedding Model available. Please configure a valid provider "
-                 f"in the Settings -> Model Provider.")
        try:
            documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
        except ProviderTokenNotInitError as ex:
@@ -307,11 +245,6 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource):
-     dataset_and_document_fields = {
-         'dataset': fields.Nested(dataset_fields),
-         'documents': fields.List(fields.Nested(document_fields)),
-         'batch': fields.String
-     }
    @setup_required
    @login_required
@@ -328,16 +261,20 @@ class DatasetInitApi(Resource):
        parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
+         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
+                             location='json')
        args = parser.parse_args()
+         if args['indexing_technique'] == 'high_quality':
            try:
                ModelFactory.get_embedding_model(
                    tenant_id=current_user.current_tenant_id
                )
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    f"No Embedding Model available. Please configure a valid provider "
                    f"in the Settings -> Model Provider.")
+             except ProviderTokenNotInitError as ex:
+                 raise ProviderNotInitializeError(ex.description)

        # validate args
        DocumentService.document_create_args_validate(args)
@@ -406,11 +343,14 @@ class DocumentIndexingEstimateApi(DocumentResource):
            try:
                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
-                                                                   data_process_rule_dict)
+                                                                   data_process_rule_dict, None,
+                                                                   'English', dataset_id)
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    f"No Embedding Model available. Please configure a valid provider "
                    f"in the Settings -> Model Provider.")
+             except ProviderTokenNotInitError as ex:
+                 raise ProviderNotInitializeError(ex.description)

        return response
@@ -473,46 +413,34 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
            indexing_runner = IndexingRunner()
            try:
                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
-                                                                   data_process_rule_dict)
+                                                                   data_process_rule_dict, None,
+                                                                   'English', dataset_id)
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    f"No Embedding Model available. Please configure a valid provider "
                    f"in the Settings -> Model Provider.")
+             except ProviderTokenNotInitError as ex:
+                 raise ProviderNotInitializeError(ex.description)
-         elif dataset.data_source_type:
+         elif dataset.data_source_type == 'notion_import':
            indexing_runner = IndexingRunner()
            try:
                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
                                                                    info_list,
-                                                                     data_process_rule_dict)
+                                                                     data_process_rule_dict,
+                                                                     None, 'English', dataset_id)
            except LLMBadRequestError:
                raise ProviderNotInitializeError(
                    f"No Embedding Model available. Please configure a valid provider "
                    f"in the Settings -> Model Provider.")
+             except ProviderTokenNotInitError as ex:
+                 raise ProviderNotInitializeError(ex.description)
        else:
            raise ValueError('Data source type not support')
        return response
class DocumentBatchIndexingStatusApi(DocumentResource):
-     document_status_fields = {
-         'id': fields.String,
-         'indexing_status': fields.String,
-         'processing_started_at': TimestampField,
-         'parsing_completed_at': TimestampField,
-         'cleaning_completed_at': TimestampField,
-         'splitting_completed_at': TimestampField,
-         'completed_at': TimestampField,
-         'paused_at': TimestampField,
-         'error': fields.String,
-         'stopped_at': TimestampField,
-         'completed_segments': fields.Integer,
-         'total_segments': fields.Integer,
-     }
-     document_status_fields_list = {
-         'data': fields.List(fields.Nested(document_status_fields))
-     }
    @setup_required
    @login_required
@@ -532,7 +460,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
            document.total_segments = total_segments
            if document.is_paused:
                document.indexing_status = 'paused'
-             documents_status.append(marshal(document, self.document_status_fields))
+             documents_status.append(marshal(document, document_status_fields))
        data = {
            'data': documents_status
        }
@@ -540,20 +468,6 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
class DocumentIndexingStatusApi(DocumentResource):
-     document_status_fields = {
-         'id': fields.String,
-         'indexing_status': fields.String,
-         'processing_started_at': TimestampField,
-         'parsing_completed_at': TimestampField,
-         'cleaning_completed_at': TimestampField,
-         'splitting_completed_at': TimestampField,
-         'completed_at': TimestampField,
-         'paused_at': TimestampField,
-         'error': fields.String,
-         'stopped_at': TimestampField,
-         'completed_segments': fields.Integer,
-         'total_segments': fields.Integer,
-     }
    @setup_required
    @login_required
@@ -575,8 +489,9 @@ class DocumentIndexingStatusApi(DocumentResource):
        document.completed_segments = completed_segments
        document.total_segments = total_segments
-         return marshal(document, self.document_status_fields)
+         if document.is_paused:
+             document.indexing_status = 'paused'
+         return marshal(document, document_status_fields)
class DocumentDetailApi(DocumentResource):
@@ -709,6 +624,12 @@ class DocumentDeleteApi(DocumentResource):
    def delete(self, dataset_id, document_id):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
+         dataset = DatasetService.get_dataset(dataset_id)
+         if dataset is None:
+             raise NotFound("Dataset not found.")
+         # check user's model setting
+         DatasetService.check_dataset_model_setting(dataset)
        document = self.get_document(dataset_id, document_id)

        try:
try: try:
@ -749,11 +670,13 @@ class DocumentMetadataApi(DocumentResource):
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
document.doc_metadata = {} document.doc_metadata = {}
if doc_type == 'others':
for key, value_type in metadata_schema.items(): document.doc_metadata = doc_metadata
value = doc_metadata.get(key) else:
if value is not None and isinstance(value, value_type): for key, value_type in metadata_schema.items():
document.doc_metadata[key] = value value = doc_metadata.get(key)
if value is not None and isinstance(value, value_type):
document.doc_metadata[key] = value
document.doc_type = doc_type document.doc_type = doc_type
document.updated_at = datetime.utcnow() document.updated_at = datetime.utcnow()
@@ -769,6 +692,12 @@ class DocumentStatusApi(DocumentResource):
    def patch(self, dataset_id, document_id, action):
        dataset_id = str(dataset_id)
        document_id = str(document_id)
+         dataset = DatasetService.get_dataset(dataset_id)
+         if dataset is None:
+             raise NotFound("Dataset not found.")
+         # check user's model setting
+         DatasetService.check_dataset_model_setting(dataset)
        document = self.get_document(dataset_id, document_id)

        # The role of the current user in the ta table must be admin or owner
@@ -832,12 +761,40 @@ class DocumentStatusApi(DocumentResource):
             remove_document_from_index_task.delay(document_id)

+            return {'result': 'success'}, 200
+
+        elif action == "un_archive":
+            if not document.archived:
+                raise InvalidActionError('Document is not archived.')
+            # check document limit
+            if current_app.config['EDITION'] == 'CLOUD':
+                documents_count = DocumentService.get_tenant_documents_count()
+                total_count = documents_count + 1
+                tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+                if total_count > tenant_document_count:
+                    raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
+            document.archived = False
+            document.archived_at = None
+            document.archived_by = None
+            document.updated_at = datetime.utcnow()
+            db.session.commit()
+
+            # Set cache to prevent indexing the same document multiple times
+            redis_client.setex(indexing_cache_key, 600, 1)
+
+            add_document_to_index_task.delay(document_id)

             return {'result': 'success'}, 200
         else:
             raise InvalidActionError()


 class DocumentPauseApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
     def patch(self, dataset_id, document_id):
         """pause document."""
         dataset_id = str(dataset_id)
@@ -867,6 +824,9 @@ class DocumentPauseApi(DocumentResource):

 class DocumentRecoverApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
     def patch(self, dataset_id, document_id):
         """recover document."""
         dataset_id = str(dataset_id)
@@ -892,6 +852,21 @@ class DocumentRecoverApi(DocumentResource):
         return {'result': 'success'}, 204


+class DocumentLimitApi(DocumentResource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        """get document limit"""
+        documents_count = DocumentService.get_tenant_documents_count()
+        tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
+        return {
+            'documents_count': documents_count,
+            'documents_limit': tenant_document_count
+        }, 200
+
+
 api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
 api.add_resource(DatasetDocumentListApi,
                  '/datasets/<uuid:dataset_id>/documents')
@@ -917,3 +892,4 @@ api.add_resource(DocumentStatusApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
 api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
 api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
+api.add_resource(DocumentLimitApi, '/datasets/limit')
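The un_archive branch above gates re-indexing on a tenant-wide document quota in the cloud edition. A minimal sketch of that check, with the document count and the configured limit passed in as plain parameters (an assumption for illustration; the real code reads `TENANT_DOCUMENT_COUNT` from Flask config and counts via `DocumentService.get_tenant_documents_count()`):

```python
# Hedged sketch of the cloud-edition document limit check from the
# un_archive branch; parameters stand in for Flask config and the
# DocumentService count query.
def check_document_limit(documents_count, tenant_document_count, edition='CLOUD'):
    """Raise ValueError if adding one more document would exceed the quota."""
    if edition != 'CLOUD':
        return  # only the cloud edition enforces this quota
    total_count = documents_count + 1
    if total_count > tenant_document_count:
        raise ValueError(f"All your documents have overed limit {tenant_document_count}.")

# Un-archiving a document when 4 of 5 slots are used passes the check.
check_document_limit(documents_count=4, tenant_document_count=5)
```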


@@ -1,53 +1,30 @@
 # -*- coding:utf-8 -*-
+import uuid
 from datetime import datetime
-from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, fields, marshal
+
+from flask import request
+from flask_login import current_user
+from flask_restful import Resource, reqparse, marshal
 from werkzeug.exceptions import NotFound, Forbidden

 import services
 from controllers.console import api
-from controllers.console.datasets.error import InvalidActionError
+from controllers.console.app.error import ProviderNotInitializeError
+from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.model_providers.model_factory import ModelFactory
+from libs.login import login_required
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from fields.segment_fields import segment_fields
 from models.dataset import DocumentSegment
-from libs.helper import TimestampField
 from services.dataset_service import DatasetService, DocumentService, SegmentService
 from tasks.enable_segment_to_index_task import enable_segment_to_index_task
-from tasks.remove_segment_from_index_task import remove_segment_from_index_task
+from tasks.disable_segment_from_index_task import disable_segment_from_index_task
+from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
+
+import pandas as pd

-segment_fields = {
-    'id': fields.String,
-    'position': fields.Integer,
-    'document_id': fields.String,
-    'content': fields.String,
-    'answer': fields.String,
-    'word_count': fields.Integer,
-    'tokens': fields.Integer,
-    'keywords': fields.List(fields.String),
-    'index_node_id': fields.String,
-    'index_node_hash': fields.String,
-    'hit_count': fields.Integer,
-    'enabled': fields.Boolean,
-    'disabled_at': TimestampField,
-    'disabled_by': fields.String,
-    'status': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'indexing_at': TimestampField,
-    'completed_at': TimestampField,
-    'error': fields.String,
-    'stopped_at': TimestampField
-}
-
-segment_list_response = {
-    'data': fields.List(fields.Nested(segment_fields)),
-    'has_more': fields.Boolean,
-    'limit': fields.Integer
-}
 class DatasetDocumentSegmentListApi(Resource):
@@ -142,7 +119,8 @@ class DatasetDocumentSegmentApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
@@ -151,6 +129,20 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)

         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
@@ -197,7 +189,7 @@ class DatasetDocumentSegmentApi(Resource):
             # Set cache to prevent indexing the same segment multiple times
             redis_client.setex(indexing_cache_key, 600, 1)

-            remove_segment_from_index_task.delay(segment.id)
+            disable_segment_from_index_task.delay(segment.id)

             return {'result': 'success'}, 200
         else:
@@ -222,6 +214,20 @@ class DatasetDocumentSegmentAddApi(Resource):
         # The role of the current user in the ta table must be admin or owner
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
+        # check embedding model setting
+        if dataset.indexing_technique == 'high_quality':
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         try:
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
@@ -233,7 +239,7 @@ class DatasetDocumentSegmentAddApi(Resource):
         parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
-        segment = SegmentService.create_segment(args, document)
+        segment = SegmentService.create_segment(args, document, dataset)
         return {
             'data': marshal(segment, segment_fields),
             'doc_form': document.doc_form
@@ -250,12 +256,28 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
             raise NotFound('Document not found.')
+        if dataset.indexing_technique == 'high_quality':
+            # check embedding model setting
+            try:
+                ModelFactory.get_embedding_model(
+                    tenant_id=current_user.current_tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
         # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
             DocumentSegment.id == str(segment_id),
@@ -277,12 +299,115 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
-        segment = SegmentService.update_segment(args, segment, document)
+        segment = SegmentService.update_segment(args, segment, document, dataset)
         return {
             'data': marshal(segment, segment_fields),
             'doc_form': document.doc_form
         }, 200

+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, dataset_id, document_id, segment_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id),
+            DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound('Segment not found.')
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        SegmentService.delete_segment(segment, document, dataset)
+        return {'result': 'success'}, 200
+
+
+class DatasetDocumentSegmentBatchImportApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, dataset_id, document_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound('Dataset not found.')
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound('Document not found.')
+        # check file before accessing it
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+        # get file from request
+        file = request.files['file']
+        # check file type
+        if not file.filename.endswith('.csv'):
+            raise ValueError("Invalid file type. Only CSV files are allowed")
+        try:
+            # the first row is treated as the header and skipped
+            df = pd.read_csv(file)
+            result = []
+            for index, row in df.iterrows():
+                if document.doc_form == 'qa_model':
+                    data = {'content': row[0], 'answer': row[1]}
+                else:
+                    data = {'content': row[0]}
+                result.append(data)
+            if len(result) == 0:
+                raise ValueError("The CSV file is empty.")
+            # async job
+            job_id = str(uuid.uuid4())
+            indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
+            # send batch add segments task
+            redis_client.setnx(indexing_cache_key, 'waiting')
+            batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
+                                                     current_user.current_tenant_id, current_user.id)
+        except Exception as e:
+            return {'error': str(e)}, 500
+        return {
+            'job_id': job_id,
+            'job_status': 'waiting'
+        }, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, job_id):
+        job_id = str(job_id)
+        indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
+        cache_result = redis_client.get(indexing_cache_key)
+        if cache_result is None:
+            raise ValueError("The job does not exist.")
+        return {
+            'job_id': job_id,
+            'job_status': cache_result.decode()
+        }, 200
+

 api.add_resource(DatasetDocumentSegmentListApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
@@ -292,3 +417,6 @@ api.add_resource(DatasetDocumentSegmentAddApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
 api.add_resource(DatasetDocumentSegmentUpdateApi,
                  '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
+api.add_resource(DatasetDocumentSegmentBatchImportApi,
+                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
+                 '/datasets/batch_import_status/<uuid:job_id>')
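The batch-import endpoint above registers a job under a Redis key with SETNX and lets clients poll its status at `/datasets/batch_import_status/<job_id>`. A minimal sketch of that pattern, using an in-memory dict as a stand-in for `redis_client` (an assumption for illustration; the real controller uses the shared Redis client and a Celery task updates the status):

```python
# Hedged sketch of the job-status pattern from the batch-import API.
import uuid


class FakeRedis:
    """In-memory stand-in for the Redis client used by the controller."""
    def __init__(self):
        self._store = {}

    def setnx(self, key, value):
        # SETNX semantics: only set the key if it does not already exist
        if key not in self._store:
            self._store[key] = str(value).encode()
            return True
        return False

    def get(self, key):
        return self._store.get(key)


redis_client = FakeRedis()


def start_batch_import_job():
    """Register a new batch-import job as 'waiting' and return its id."""
    job_id = str(uuid.uuid4())
    cache_key = 'segment_batch_import_{}'.format(job_id)
    redis_client.setnx(cache_key, 'waiting')
    return job_id


def get_job_status(job_id):
    """Look up the job status; raise if the job was never registered."""
    cache_key = 'segment_batch_import_{}'.format(job_id)
    cached = redis_client.get(cache_key)
    if cached is None:
        raise ValueError("The job does not exist.")
    return {'job_id': job_id, 'job_status': cached.decode()}
```

SETNX keeps a retried POST from clobbering a status that a worker has already advanced past 'waiting'.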


@@ -1,45 +1,39 @@
-import datetime
-import hashlib
-import tempfile
-import chardet
-import time
-import uuid
-from pathlib import Path
 from cachetools import TTLCache
 from flask import request, current_app
-from flask_login import login_required, current_user
-from flask_restful import Resource, marshal_with, fields
-from werkzeug.exceptions import NotFound
+
+import services
+from libs.login import login_required
+from flask_restful import Resource, marshal_with
 from controllers.console import api
 from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
     UnsupportedFileTypeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.data_loader.file_extractor import FileExtractor
-from extensions.ext_storage import storage
-from libs.helper import TimestampField
-from extensions.ext_database import db
-from models.model import UploadFile
+from fields.file_fields import upload_config_fields, file_fields
+from services.file_service import FileService

 cache = TTLCache(maxsize=None, ttl=30)

-FILE_SIZE_LIMIT = 15 * 1024 * 1024  # 15MB
-ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx']
+ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
 PREVIEW_WORDS_LIMIT = 3000


 class FileApi(Resource):
-    file_fields = {
-        'id': fields.String,
-        'name': fields.String,
-        'size': fields.Integer,
-        'extension': fields.String,
-        'mime_type': fields.String,
-        'created_by': fields.String,
-        'created_at': TimestampField,
-    }
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(upload_config_fields)
+    def get(self):
+        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
+        batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
+        return {
+            'file_size_limit': file_size_limit,
+            'batch_count_limit': batch_count_limit
+        }, 200

     @setup_required
     @login_required
@@ -56,44 +50,13 @@ class FileApi(Resource):
         if len(request.files) > 1:
             raise TooManyFilesError()

-        file_content = file.read()
-        file_size = len(file_content)
-
-        if file_size > FILE_SIZE_LIMIT:
-            message = "({file_size} > {FILE_SIZE_LIMIT})"
-            raise FileTooLargeError(message)
-
-        extension = file.filename.split('.')[-1]
-        if extension not in ALLOWED_EXTENSIONS:
-            raise UnsupportedFileTypeError()
-
-        # user uuid as file name
-        file_uuid = str(uuid.uuid4())
-        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
-
-        # save file to storage
-        storage.save(file_key, file_content)
-
-        # save file to db
-        config = current_app.config
-        upload_file = UploadFile(
-            tenant_id=current_user.current_tenant_id,
-            storage_type=config['STORAGE_TYPE'],
-            key=file_key,
-            name=file.filename,
-            size=file_size,
-            extension=extension,
-            mime_type=file.mimetype,
-            created_by=current_user.id,
-            created_at=datetime.datetime.utcnow(),
-            used=False,
-            hash=hashlib.sha3_256(file_content).hexdigest()
-        )
-
-        db.session.add(upload_file)
-        db.session.commit()
+        try:
+            upload_file = FileService.upload_file(file)
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()

         return upload_file, 201
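The hunk above moves the inline extension and size checks into `FileService.upload_file`, which now raises service-level errors that the controller translates into HTTP errors. A minimal sketch of the validation itself, with the limit hard-coded (an assumption for illustration; the refactored service reads `UPLOAD_FILE_SIZE_LIMIT` from config rather than a constant):

```python
# Hedged sketch of the upload validation that moved into FileService:
# extension allow-list plus size limit. 15 MB mirrors the old
# hard-coded FILE_SIZE_LIMIT; the real service reads it from config.
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
FILE_SIZE_LIMIT = 15 * 1024 * 1024  # 15MB


def validate_upload(filename: str, content: bytes) -> str:
    """Return the lower-cased extension if the upload passes both checks."""
    extension = filename.rsplit('.', 1)[-1].lower()
    if extension not in ALLOWED_EXTENSIONS:
        raise ValueError(f"Unsupported file type: .{extension}")
    if len(content) > FILE_SIZE_LIMIT:
        raise ValueError(f"File too large ({len(content)} > {FILE_SIZE_LIMIT})")
    return extension
```

Centralizing this in a service lets the console API and other callers share one set of checks instead of duplicating the allow-list per controller.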
@@ -103,26 +66,7 @@ class FilePreviewApi(Resource):
     @account_initialization_required
     def get(self, file_id):
         file_id = str(file_id)
-
-        key = file_id + request.path
-        cached_response = cache.get(key)
-        if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
-            return cached_response['response']
-
-        upload_file = db.session.query(UploadFile) \
-            .filter(UploadFile.id == file_id) \
-            .first()
-
-        if not upload_file:
-            raise NotFound("File not found")
-
-        # extract text from file
-        extension = upload_file.extension
-        if extension not in ALLOWED_EXTENSIONS:
-            raise UnsupportedFileTypeError()
-
-        text = FileExtractor.load(upload_file, return_text=True)
-        text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
+        text = FileService.get_file_preview(file_id)
         return {'content': text}


@@ -1,7 +1,8 @@
 import logging

-from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, marshal, fields
+from flask_login import current_user
+from libs.login import login_required
+from flask_restful import Resource, reqparse, marshal
 from werkzeug.exceptions import InternalServerError, NotFound, Forbidden

 import services
@@ -11,49 +12,12 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
-from libs.helper import TimestampField
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+    LLMBadRequestError
+from fields.hit_testing_fields import hit_testing_record_fields
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService

-document_fields = {
-    'id': fields.String,
-    'data_source_type': fields.String,
-    'name': fields.String,
-    'doc_type': fields.String,
-}
-
-segment_fields = {
-    'id': fields.String,
-    'position': fields.Integer,
-    'document_id': fields.String,
-    'content': fields.String,
-    'answer': fields.String,
-    'word_count': fields.Integer,
-    'tokens': fields.Integer,
-    'keywords': fields.List(fields.String),
-    'index_node_id': fields.String,
-    'index_node_hash': fields.String,
-    'hit_count': fields.Integer,
-    'enabled': fields.Boolean,
-    'disabled_at': TimestampField,
-    'disabled_by': fields.String,
-    'status': fields.String,
-    'created_by': fields.String,
-    'created_at': TimestampField,
-    'indexing_at': TimestampField,
-    'completed_at': TimestampField,
-    'error': fields.String,
-    'stopped_at': TimestampField,
-    'document': fields.Nested(document_fields),
-}
-
-hit_testing_record_fields = {
-    'segment': fields.Nested(segment_fields),
-    'score': fields.Float,
-    'tsne_position': fields.Raw
-}
 class HitTestingApi(Resource):
@@ -102,6 +66,10 @@ class HitTestingApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
         except ValueError as e:
             raise ValueError(str(e))
         except Exception as e:


@@ -31,8 +31,9 @@ class CompletionApi(InstalledAppResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, location='json')
+        parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         args = parser.parse_args()

         streaming = args['response_mode'] == 'streaming'
@@ -92,6 +93,7 @@ class ChatApi(InstalledAppResource):
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         args = parser.parse_args()

         streaming = args['response_mode'] == 'streaming'


@@ -7,26 +7,12 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.explore.error import NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
+from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import TimestampField, uuid_value
 from services.conversation_service import ConversationService
 from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
 from services.web_conversation_service import WebConversationService

-conversation_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'inputs': fields.Raw,
-    'status': fields.String,
-    'introduction': fields.String,
-    'created_at': TimestampField
-}
-
-conversation_infinite_scroll_pagination_fields = {
-    'limit': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(conversation_fields))
-}


 class ConversationListApi(InstalledAppResource):
@@ -76,7 +62,7 @@ class ConversationApi(InstalledAppResource):

 class ConversationRenameApi(InstalledAppResource):

-    @marshal_with(conversation_fields)
+    @marshal_with(simple_conversation_fields)
     def post(self, installed_app, c_id):
         app_model = installed_app.app
         if app_model.mode != 'chat':


@@ -1,8 +1,9 @@
 # -*- coding:utf-8 -*-
 from datetime import datetime

-from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, fields, marshal_with, inputs
+from flask_login import current_user
+from libs.login import login_required
+from flask_restful import Resource, reqparse, marshal_with, inputs
 from sqlalchemy import and_
 from werkzeug.exceptions import NotFound, Forbidden, BadRequest
@@ -10,32 +11,10 @@ from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
 from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
-from libs.helper import TimestampField
+from fields.installed_app_fields import installed_app_list_fields
 from models.model import App, InstalledApp, RecommendedApp
 from services.account_service import TenantService

-app_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'mode': fields.String,
-    'icon': fields.String,
-    'icon_background': fields.String
-}
-
-installed_app_fields = {
-    'id': fields.String,
-    'app': fields.Nested(app_fields),
-    'app_owner_tenant_id': fields.String,
-    'is_pinned': fields.Boolean,
-    'last_used_at': TimestampField,
-    'editable': fields.Boolean,
-    'uninstallable': fields.Boolean,
-}
-
-installed_app_list_fields = {
-    'installed_apps': fields.List(fields.Nested(installed_app_fields))
-}


 class InstalledAppsListApi(Resource):
     @login_required


@@ -17,6 +17,7 @@ from controllers.console.explore.error import NotCompletionAppError, AppSuggeste
 from controllers.console.explore.wraps import InstalledAppResource
 from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
+from fields.message_fields import message_infinite_scroll_pagination_fields
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService
 from services.errors.app import MoreLikeThisDisabledError
@@ -26,25 +27,6 @@ from services.message_service import MessageService

 class MessageListApi(InstalledAppResource):
-    feedback_fields = {
-        'rating': fields.String
-    }
-
-    message_fields = {
-        'id': fields.String,
-        'conversation_id': fields.String,
-        'inputs': fields.Raw,
-        'query': fields.String,
-        'answer': fields.String,
-        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
-        'created_at': TimestampField
-    }
-
-    message_infinite_scroll_pagination_fields = {
-        'limit': fields.Integer,
-        'has_more': fields.Boolean,
-        'data': fields.List(fields.Nested(message_fields))
-    }

     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, installed_app):

View File

@@ -24,6 +24,7 @@ class AppParameterApi(InstalledAppResource):
         'suggested_questions': fields.Raw,
         'suggested_questions_after_answer': fields.Raw,
         'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
     }
@@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource):
         'suggested_questions': app_model_config.suggested_questions_list,
         'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
         'speech_to_text': app_model_config.speech_to_text_dict,
+        'retriever_resource': app_model_config.retriever_resource_dict,
         'more_like_this': app_model_config.more_like_this_dict,
         'user_input_form': app_model_config.user_input_form_list
     }
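The hunk above adds a `retriever_resource` key to the public `/parameters` payload. As a quick illustration (field values are made-up placeholders, not taken from a live Dify instance), the extended response now carries one more feature toggle alongside the existing ones:

```python
import json

# Illustrative shape of the extended /parameters response; the values below
# are example placeholders, only the key set mirrors the diff.
parameters = {
    'opening_statement': '',
    'suggested_questions': [],
    'suggested_questions_after_answer': {'enabled': False},
    'speech_to_text': {'enabled': False},
    'retriever_resource': {'enabled': True},  # newly exposed by this change
    'more_like_this': {'enabled': False},
    'user_input_form': [],
}

print(json.dumps(parameters, indent=2))
```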

View File

@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource, fields, marshal_with
 from sqlalchemy import and_

View File

@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource
 from functools import wraps

View File

@@ -1,7 +1,6 @@
 # -*- coding:utf-8 -*-
 from functools import wraps
-import flask_login
 from flask import request, current_app
 from flask_restful import Resource, reqparse
@@ -58,9 +57,6 @@ class SetupApi(Resource):
         )
         setup()
-        # Login
-        flask_login.login_user(account)
         AccountService.update_last_login(account, request)
         return {'result': 'success'}, 201

View File

@@ -29,6 +29,7 @@ class UniversalChatApi(UniversalChatResource):
         parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
         parser.add_argument('tools', type=list, required=True, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json')
         args = parser.parse_args()
         app_model_config = app_model.app_model_config

View File

@@ -6,31 +6,17 @@ from werkzeug.exceptions import NotFound
 from controllers.console import api
 from controllers.console.universal_chat.wraps import UniversalChatResource
+from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \
+    conversation_with_model_config_fields
 from libs.helper import TimestampField, uuid_value
 from services.conversation_service import ConversationService
 from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
 from services.web_conversation_service import WebConversationService
-conversation_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'inputs': fields.Raw,
-    'status': fields.String,
-    'introduction': fields.String,
-    'created_at': TimestampField,
-    'model_config': fields.Raw,
-}
-conversation_infinite_scroll_pagination_fields = {
-    'limit': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(conversation_fields))
-}
 class UniversalChatConversationListApi(UniversalChatResource):
-    @marshal_with(conversation_infinite_scroll_pagination_fields)
+    @marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
     def get(self, universal_app):
         app_model = universal_app
@@ -73,7 +59,7 @@ class UniversalChatConversationApi(UniversalChatResource):
 class UniversalChatConversationRenameApi(UniversalChatResource):
-    @marshal_with(conversation_fields)
+    @marshal_with(conversation_with_model_config_fields)
     def post(self, universal_app, c_id):
         app_model = universal_app
         conversation_id = str(c_id)

View File

@@ -36,6 +36,25 @@ class UniversalChatMessageListApi(UniversalChatResource):
         'created_at': TimestampField
     }
+    retriever_resource_fields = {
+        'id': fields.String,
+        'message_id': fields.String,
+        'position': fields.Integer,
+        'dataset_id': fields.String,
+        'dataset_name': fields.String,
+        'document_id': fields.String,
+        'document_name': fields.String,
+        'data_source_type': fields.String,
+        'segment_id': fields.String,
+        'score': fields.Float,
+        'hit_count': fields.Integer,
+        'word_count': fields.Integer,
+        'segment_position': fields.Integer,
+        'index_node_hash': fields.String,
+        'content': fields.String,
+        'created_at': TimestampField
+    }
     message_fields = {
         'id': fields.String,
@@ -43,6 +62,7 @@ class UniversalChatMessageListApi(UniversalChatResource):
         'query': fields.String,
         'answer': fields.String,
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField,
         'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
     }

View File

@@ -1,4 +1,6 @@
 # -*- coding:utf-8 -*-
+import json
 from flask_restful import marshal_with, fields
 from controllers.console import api
@@ -14,6 +16,7 @@ class UniversalChatParameterApi(UniversalChatResource):
         'suggested_questions': fields.Raw,
         'suggested_questions_after_answer': fields.Raw,
         'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
     }
     @marshal_with(parameters_fields)
@@ -21,12 +24,14 @@ class UniversalChatParameterApi(UniversalChatResource):
         """Retrieve app parameters."""
         app_model = universal_app
         app_model_config = app_model.app_model_config
+        app_model_config.retriever_resource = json.dumps({'enabled': True})
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
             'speech_to_text': app_model_config.speech_to_text_dict,
+            'retriever_resource': app_model_config.retriever_resource_dict,
         }

View File

@@ -1,7 +1,8 @@
 import json
 from functools import wraps
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
@@ -46,6 +47,7 @@ def universal_chat_app_required(view=None):
             suggested_questions=json.dumps([]),
             suggested_questions_after_answer=json.dumps({'enabled': True}),
             speech_to_text=json.dumps({'enabled': True}),
+            retriever_resource=json.dumps({'enabled': True}),
             more_like_this=None,
             sensitive_word_avoidance=None,
             model=json.dumps({

View File

@@ -38,12 +38,20 @@ class StripeWebhookApi(Resource):
             logging.debug(event['data']['object']['payment_status'])
             logging.debug(event['data']['object']['metadata'])
+            session = stripe.checkout.Session.retrieve(
+                event['data']['object']['id'],
+                expand=['line_items'],
+            )
+            logging.debug(session.line_items['data'][0]['quantity'])
             # Fulfill the purchase...
             provider_checkout_service = ProviderCheckoutService()
             try:
-                provider_checkout_service.fulfill_provider_order(event)
+                provider_checkout_service.fulfill_provider_order(event, session.line_items)
             except Exception as e:
                 logging.debug(str(e))
         return 'success', 200
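The change above re-fetches the Checkout Session with `expand=['line_items']` because the purchased quantity lives on the session's line items, not on the bare webhook event. A minimal sketch of that extraction, using a hand-written stand-in payload in place of a real `stripe.checkout.Session.retrieve(...)` call (the `cs_test_123` id and quantity are made up):

```python
# Mocked stand-in for the expanded Checkout Session the webhook retrieves.
mock_session = {
    'id': 'cs_test_123',  # hypothetical session id
    'line_items': {
        'data': [
            {'quantity': 3},
        ]
    },
}


def purchased_quantity(session: dict) -> int:
    # Mirrors the access path logged in the diff:
    # session.line_items['data'][0]['quantity']
    return session['line_items']['data'][0]['quantity']


print(purchased_quantity(mock_session))  # → 3
```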

View File

@@ -3,7 +3,8 @@ from datetime import datetime
 import pytz
 from flask import current_app, request
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource, reqparse, fields, marshal_with
 from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError

View File

@@ -1,6 +1,7 @@
 # -*- coding:utf-8 -*-
 from flask import current_app
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
 import services
@@ -48,46 +49,43 @@ class MemberInviteEmailApi(Resource):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('email', type=str, required=True, location='json')
+        parser.add_argument('emails', type=str, required=True, location='json', action='append')
         parser.add_argument('role', type=str, required=True, default='admin', location='json')
         args = parser.parse_args()
-        invitee_email = args['email']
+        invitee_emails = args['emails']
         invitee_role = args['role']
         if invitee_role not in ['admin', 'normal']:
             return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
         inviter = current_user
-        try:
-            token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
-                                                      inviter=inviter)
-            account = db.session.query(Account, TenantAccountJoin.role).join(
-                TenantAccountJoin, Account.id == TenantAccountJoin.account_id
-            ).filter(Account.email == args['email']).first()
-            account, role = account
-            account = marshal(account, account_fields)
-            account['role'] = role
-        except services.errors.account.CannotOperateSelfError as e:
-            return {'code': 'cannot-operate-self', 'message': str(e)}, 400
-        except services.errors.account.NoPermissionError as e:
-            return {'code': 'forbidden', 'message': str(e)}, 403
-        except services.errors.account.AccountAlreadyInTenantError as e:
-            return {'code': 'email-taken', 'message': str(e)}, 409
-        except Exception as e:
-            return {'code': 'unexpected-error', 'message': str(e)}, 500
-        # todo:413
+        invitation_results = []
+        console_web_url = current_app.config.get("CONSOLE_WEB_URL")
+        for invitee_email in invitee_emails:
+            try:
+                token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role,
+                                                          inviter=inviter)
+                account = db.session.query(Account, TenantAccountJoin.role).join(
+                    TenantAccountJoin, Account.id == TenantAccountJoin.account_id
+                ).filter(Account.email == invitee_email).first()
+                account, role = account
+                invitation_results.append({
+                    'status': 'success',
+                    'email': invitee_email,
+                    'url': f'{console_web_url}/activate?email={invitee_email}&token={token}'
+                })
+                account = marshal(account, account_fields)
+                account['role'] = role
+            except Exception as e:
+                invitation_results.append({
+                    'status': 'failed',
+                    'email': invitee_email,
+                    'message': str(e)
+                })
         return {
             'result': 'success',
-            'account': account,
-            'invite_url': '{}/activate?workspace_id={}&email={}&token={}'.format(
-                current_app.config.get("CONSOLE_WEB_URL"),
-                str(current_user.current_tenant_id),
-                invitee_email,
-                token
-            )
+            'invitation_results': invitation_results,
         }, 201
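The rewritten endpoint switches from one `email` to a list of `emails` and from fail-fast error codes to per-recipient results: one bad address no longer aborts the whole request. A self-contained sketch of that aggregation pattern, where `fake_invite` is a stand-in for `RegisterService.invite_new_member` and the console URL is a placeholder:

```python
def fake_invite(email: str) -> str:
    """Stand-in for RegisterService.invite_new_member."""
    if '@' not in email:
        raise ValueError('invalid email')
    return 'token-' + email  # hypothetical activation token


def invite_all(emails):
    # Collect a success/failed record per address instead of returning early.
    results = []
    for email in emails:
        try:
            token = fake_invite(email)
            results.append({
                'status': 'success',
                'email': email,
                'url': f'https://console.example.com/activate?email={email}&token={token}',
            })
        except Exception as e:
            results.append({'status': 'failed', 'email': email, 'message': str(e)})
    return results


results = invite_all(['a@dify.ai', 'broken-address'])
print([r['status'] for r in results])  # → ['success', 'failed']
```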

View File

@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
@@ -245,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource):
                 'enabled': v.enabled,
                 'min': v.min,
                 'max': v.max,
-                'default': v.default
+                'default': v.default,
+                'precision': v.precision
             }
             for k, v in vars(parameter_rules).items()
         }
@@ -284,6 +286,25 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
         return result
+class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider_name: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', type=str, required=False, nullable=True, location='args')
+        args = parser.parse_args()
+        provider_service = ProviderService()
+        result = provider_service.free_quota_qualification_verify(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name,
+            token=args['token']
+        )
+        return result
 api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
 api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
 api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
@@ -299,3 +320,5 @@ api.add_resource(ModelProviderPaymentCheckoutUrlApi,
                  '/workspaces/current/model-providers/<string:provider_name>/checkout-url')
 api.add_resource(ModelProviderFreeQuotaSubmitApi,
                  '/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
+api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
+                 '/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')

View File

@@ -1,4 +1,5 @@
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource, reqparse
 from controllers.console import api
from controllers.console import api from controllers.console import api

View File

@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden

View File

@@ -1,6 +1,7 @@
 import json
-from flask_login import login_required, current_user
+from flask_login import current_user
+from libs.login import login_required
 from flask_restful import Resource, abort, reqparse
 from werkzeug.exceptions import Forbidden

View File

@@ -2,10 +2,12 @@
 import logging
 from flask import request
-from flask_login import login_required, current_user
-from flask_restful import Resource, fields, marshal_with, reqparse, marshal
+from flask_login import current_user
+from libs.login import login_required
+from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
 from controllers.console import api
+from controllers.console.admin import admin_required
 from controllers.console.setup import setup_required
 from controllers.console.error import AccountNotLinkTenantError
 from controllers.console.wraps import account_initialization_required
@@ -43,6 +45,13 @@ tenants_fields = {
     'current': fields.Boolean
 }
+workspace_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'status': fields.String,
+    'created_at': TimestampField
+}
 class TenantListApi(Resource):
     @setup_required
@@ -57,6 +66,38 @@ class TenantListApi(Resource):
         return {'workspaces': marshal(tenants, tenants_fields)}, 200
+class WorkspaceListApi(Resource):
+    @setup_required
+    @admin_required
+    def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
+        parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
+        args = parser.parse_args()
+        tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\
+            .paginate(page=args['page'], per_page=args['limit'])
+        has_more = False
+        if len(tenants.items) == args['limit']:
+            current_page_first_tenant = tenants[-1]
+            rest_count = db.session.query(Tenant).filter(
+                Tenant.created_at < current_page_first_tenant.created_at,
+                Tenant.id != current_page_first_tenant.id
+            ).count()
+            if rest_count > 0:
+                has_more = True
+        total = db.session.query(Tenant).count()
+        return {
+            'data': marshal(tenants.items, workspace_fields),
+            'has_more': has_more,
+            'limit': args['limit'],
+            'page': args['page'],
+            'total': total
+        }, 200
 class TenantApi(Resource):
     @setup_required
     @login_required
@@ -92,6 +133,7 @@ class SwitchWorkspaceApi(Resource):
 api.add_resource(TenantListApi, '/workspaces')  # GET for getting all tenants
+api.add_resource(WorkspaceListApi, '/all-workspaces')  # GET for getting all tenants
 api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current')  # GET for getting current tenant info
 api.add_resource(TenantApi, '/info', endpoint='info')  # Deprecated
 api.add_resource(SwitchWorkspaceApi, '/workspaces/switch')  # POST for switching tenant
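The new `WorkspaceListApi` computes `has_more` not from the page count but by checking whether any tenant was created before the last row of the current page. A pure-Python sketch of that check, with plain integers standing in for `Tenant.created_at` values and the database query:

```python
def page_has_more(all_created_at, page_items, limit):
    """Mirror WorkspaceListApi's has_more logic on in-memory data."""
    if len(page_items) != limit:
        # A short page means the listing is exhausted.
        return False
    oldest_on_page = page_items[-1]
    # Stands in for the rest_count query: rows strictly older than the
    # last row on this page.
    return any(t < oldest_on_page for t in all_created_at)


tenants = [50, 40, 30, 20, 10]            # created_at values, newest first
page = tenants[:3]                        # limit = 3 → [50, 40, 30]
print(page_has_more(tenants, page, 3))    # → True (20 and 10 remain)
print(page_has_more(tenants, tenants, 5)) # → False (nothing older left)
```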

View File

@@ -9,4 +9,4 @@ api = ExternalApi(bp)
 from .app import completion, app, conversation, message, audio
-from .dataset import document
+from .dataset import document, segment, dataset

View File

@@ -25,6 +25,7 @@ class AppParameterApi(AppApiResource):
         'suggested_questions': fields.Raw,
         'suggested_questions_after_answer': fields.Raw,
         'speech_to_text': fields.Raw,
+        'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
     }
@@ -39,6 +40,7 @@ class AppParameterApi(AppApiResource):
         'suggested_questions': app_model_config.suggested_questions_list,
         'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
         'speech_to_text': app_model_config.speech_to_text_dict,
+        'retriever_resource': app_model_config.retriever_resource_dict,
         'more_like_this': app_model_config.more_like_this_dict,
         'user_input_form': app_model_config.user_input_form_list
     }

View File

@@ -27,9 +27,11 @@ class CompletionApi(AppApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
-        parser.add_argument('query', type=str, location='json')
+        parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('user', type=str, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         args = parser.parse_args()
         streaming = args['response_mode'] == 'streaming'
@@ -91,6 +93,8 @@ class ChatApi(AppApiResource):
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('user', type=str, location='json')
+        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         args = parser.parse_args()
         streaming = args['response_mode'] == 'streaming'

View File

@@ -8,25 +8,11 @@ from controllers.service_api import api
 from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
 from controllers.service_api.wraps import AppApiResource
+from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import TimestampField, uuid_value
 import services
 from services.conversation_service import ConversationService
-conversation_fields = {
-    'id': fields.String,
-    'name': fields.String,
-    'inputs': fields.Raw,
-    'status': fields.String,
-    'introduction': fields.String,
-    'created_at': TimestampField
-}
-conversation_infinite_scroll_pagination_fields = {
-    'limit': fields.Integer,
-    'has_more': fields.Boolean,
-    'data': fields.List(fields.Nested(conversation_fields))
-}
 class ConversationApi(AppApiResource):
@@ -50,7 +36,7 @@ class ConversationApi(AppApiResource):
             raise NotFound("Last Conversation Not Exists.")
 class ConversationDetailApi(AppApiResource):
-    @marshal_with(conversation_fields)
+    @marshal_with(simple_conversation_fields)
     def delete(self, app_model, end_user, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
@@ -70,7 +56,7 @@ class ConversationDetailApi(AppApiResource):
 class ConversationRenameApi(AppApiResource):
-    @marshal_with(conversation_fields)
+    @marshal_with(simple_conversation_fields)
     def post(self, app_model, end_user, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()

View File

@@ -16,6 +16,24 @@ class MessageListApi(AppApiResource):
     feedback_fields = {
         'rating': fields.String
     }
+    retriever_resource_fields = {
+        'id': fields.String,
+        'message_id': fields.String,
+        'position': fields.Integer,
+        'dataset_id': fields.String,
+        'dataset_name': fields.String,
+        'document_id': fields.String,
+        'document_name': fields.String,
+        'data_source_type': fields.String,
+        'segment_id': fields.String,
+        'score': fields.Float,
+        'hit_count': fields.Integer,
+        'word_count': fields.Integer,
+        'segment_position': fields.Integer,
+        'index_node_hash': fields.String,
+        'content': fields.String,
+        'created_at': TimestampField
+    }
     message_fields = {
         'id': fields.String,
@@ -24,6 +42,7 @@ class MessageListApi(AppApiResource):
         'query': fields.String,
         'answer': fields.String,
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField
     }

View File

@@ -0,0 +1,81 @@
+from flask import request
+from flask_restful import reqparse, marshal
+import services.dataset_service
+from controllers.service_api import api
+from controllers.service_api.dataset.error import DatasetNameDuplicateError
+from controllers.service_api.wraps import DatasetApiResource
+from libs.login import current_user
+from core.model_providers.models.entity.model_params import ModelType
+from fields.dataset_fields import dataset_detail_fields
+from services.dataset_service import DatasetService
+from services.provider_service import ProviderService
+def _validate_name(name):
+    if not name or len(name) < 1 or len(name) > 40:
+        raise ValueError('Name must be between 1 to 40 characters.')
+    return name
+class DatasetApi(DatasetApiResource):
+    """Resource for get datasets."""
+    def get(self, tenant_id):
+        page = request.args.get('page', default=1, type=int)
+        limit = request.args.get('limit', default=20, type=int)
+        provider = request.args.get('provider', default="vendor")
+        datasets, total = DatasetService.get_datasets(page, limit, provider,
+                                                      tenant_id, current_user)
+        # check embedding setting
+        provider_service = ProviderService()
+        valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
+                                                                 ModelType.EMBEDDINGS.value)
+        model_names = []
+        for valid_model in valid_model_list:
+            model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
+        data = marshal(datasets, dataset_detail_fields)
+        for item in data:
+            if item['indexing_technique'] == 'high_quality':
+                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
+                if item_model in model_names:
+                    item['embedding_available'] = True
+                else:
+                    item['embedding_available'] = False
+            else:
+                item['embedding_available'] = True
+        response = {
+            'data': data,
+            'has_more': len(datasets) == limit,
+            'limit': limit,
+            'total': total,
+            'page': page
+        }
+        return response, 200
+    """Resource for datasets."""
+    def post(self, tenant_id):
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', nullable=False, required=True,
+                            help='type is required. Name must be between 1 to 40 characters.',
+                            type=_validate_name)
+        parser.add_argument('indexing_technique', type=str, location='json',
+                            choices=('high_quality', 'economy'),
+                            help='Invalid indexing technique.')
+        args = parser.parse_args()
+        try:
+            dataset = DatasetService.create_empty_dataset(
+                tenant_id=tenant_id,
+                name=args['name'],
+                indexing_technique=args['indexing_technique'],
+                account=current_user
+            )
+        except services.errors.dataset.DatasetNameDuplicateError:
+            raise DatasetNameDuplicateError()
+        return marshal(dataset, dataset_detail_fields), 200
+api.add_resource(DatasetApi, '/datasets')

View File

@@ -1,114 +1,287 @@
import json

from flask import request
from flask_restful import reqparse, marshal
from sqlalchemy import desc
from werkzeug.exceptions import NotFound

import services.dataset_service
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
    NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource
from libs.login import current_user
from core.model_providers.error import ProviderTokenNotInitError
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService
from services.file_service import FileService


class DocumentAddByTextApi(DatasetApiResource):
    """Resource for documents."""

    def post(self, tenant_id, dataset_id):
        """Create document by text."""
        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, nullable=False, location='json')
        parser.add_argument('text', type=str, required=True, nullable=False, location='json')
        parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
        parser.add_argument('original_document_id', type=str, required=False, location='json')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                            location='json')
        parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
                            location='json')
        args = parser.parse_args()
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == tenant_id,
            Dataset.id == dataset_id
        ).first()

        if not dataset:
            raise ValueError('Dataset does not exist.')

        if not dataset.indexing_technique and not args['indexing_technique']:
            raise ValueError('indexing_technique is required.')

        upload_file = FileService.upload_text(args.get('text'), args.get('name'))
        data_source = {
            'type': 'upload_file',
            'info_list': {
                'data_source_type': 'upload_file',
                'file_info_list': {
                    'file_ids': [upload_file.id]
                }
            }
        }
        args['data_source'] = data_source
        # validate args
        DocumentService.document_create_args_validate(args)

        try:
            documents, batch = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                document_data=args,
                account=current_user,
                dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
                created_from='api'
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        document = documents[0]

        documents_and_batch_fields = {
            'document': marshal(document, document_fields),
            'batch': batch
        }
        return documents_and_batch_fields, 200
class DocumentUpdateByTextApi(DatasetApiResource):
    """Resource for update documents."""

    def post(self, tenant_id, dataset_id, document_id):
"""Update document by text."""
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument('text', type=str, required=False, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
            raise ValueError('Dataset does not exist.')
if args['text']:
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
data_source = {
'type': 'upload_file',
'info_list': {
'data_source_type': 'upload_file',
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
args['original_document_id'] = str(document_id)
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
args = {}
if 'data' in request.form:
args = json.loads(request.form['data'])
if 'doc_form' not in args:
args['doc_form'] = 'text_model'
if 'doc_language' not in args:
args['doc_language'] = 'English'
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
        if not dataset:
            raise ValueError('Dataset does not exist.')

        if not dataset.indexing_technique and not args.get('indexing_technique'):
            raise ValueError('indexing_technique is required.')
        # check file
        if 'file' not in request.files:
            raise NoFileUploadedError()
        if len(request.files) > 1:
            raise TooManyFilesError()
        # save file info
        file = request.files['file']
upload_file = FileService.upload_file(file)
data_source = {
'type': 'upload_file',
'info_list': {
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents."""
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
args = {}
if 'data' in request.form:
args = json.loads(request.form['data'])
if 'doc_form' not in args:
args['doc_form'] = 'text_model'
if 'doc_language' not in args:
args['doc_language'] = 'English'
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
            raise ValueError('Dataset does not exist.')
if 'file' in request.files:
# save file info
file = request.files['file']
if len(request.files) > 1:
raise TooManyFilesError()
upload_file = FileService.upload_file(file)
data_source = {
'type': 'upload_file',
'info_list': {
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
args['original_document_id'] = str(document_id)
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentDeleteApi(DatasetApiResource):
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document.""" """Delete document."""
document_id = str(document_id) document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
            raise ValueError('Dataset does not exist.')
        document = DocumentService.get_document(dataset.id, document_id)
@@ -126,8 +299,85 @@ class DocumentApi(DatasetApiResource):
        except services.errors.document.DocumentIndexingError:
            raise DocumentIndexingError('Cannot delete document during indexing.')

        return {'result': 'success'}, 200


class DocumentListApi(DatasetApiResource):
    def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
search = request.args.get('keyword', default=None, type=str)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
query = Document.query.filter_by(
dataset_id=str(dataset_id), tenant_id=tenant_id)
if search:
search = f'%{search}%'
query = query.filter(Document.name.like(search))
query = query.order_by(desc(Document.created_at))
paginated_documents = query.paginate(
page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
response = {
'data': marshal(documents, document_fields),
'has_more': len(documents) == limit,
'limit': limit,
'total': paginated_documents.total,
'page': page
}
return response
class DocumentIndexingStatusApi(DatasetApiResource):
def get(self, tenant_id, dataset_id, batch):
dataset_id = str(dataset_id)
batch = str(batch)
tenant_id = str(tenant_id)
# get dataset
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# get documents
documents = DocumentService.get_batch_documents(dataset_id, batch)
if not documents:
raise NotFound('Documents not found.')
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
documents_status.append(marshal(document, document_status_fields))
data = {
'data': documents_status
}
return data
api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text')
api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file')
api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text')
api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file')
api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents')
api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status')
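The registrations above define the dataset document endpoints. As a rough illustration of how a client might target `create_by_text`, the request pieces can be assembled like this (the `/v1` base path, API key, and dataset ID are assumptions for the sketch, not values from the diff):

```python
import json


def build_create_by_text_request(api_key: str, dataset_id: str, name: str, text: str,
                                 indexing_technique: str = 'high_quality') -> dict:
    """Assemble a POST to /datasets/<dataset_id>/document/create_by_text.

    The body keys mirror the reqparse arguments declared in DocumentAddByTextApi:
    'name' and 'text' are required; 'indexing_technique' matters when the dataset
    has none set yet.
    """
    return {
        'url': f'/v1/datasets/{dataset_id}/document/create_by_text',  # '/v1' prefix is an assumption
        'headers': {
            'Authorization': f'Bearer {api_key}',  # checked by validate_and_get_api_token
            'Content-Type': 'application/json',
        },
        'body': json.dumps({
            'name': name,
            'text': text,
            'indexing_technique': indexing_technique,
        }),
    }


req = build_create_by_text_request('sk-hypothetical-key', 'dataset-123', 'notes', 'hello world')
```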

View File

@@ -1,20 +1,73 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = 'file_too_large'
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415
class HighQualityDatasetOnlyError(BaseHTTPException):
error_code = 'high_quality_dataset_only'
description = "Current operation only supports 'high-quality' datasets."
code = 400
class DatasetNotInitializedError(BaseHTTPException):
error_code = 'dataset_not_initialized'
description = "The dataset is still being initialized or indexing. Please wait a moment."
code = 400
class ArchivedDocumentImmutableError(BaseHTTPException):
    error_code = 'archived_document_immutable'
    description = "The archived document is not editable."
    code = 403
class DatasetNameDuplicateError(BaseHTTPException):
error_code = 'dataset_name_duplicate'
description = "The dataset name already exists. Please modify your dataset name."
code = 409
class InvalidActionError(BaseHTTPException):
error_code = 'invalid_action'
description = "Invalid action."
code = 400
class DocumentAlreadyFinishedError(BaseHTTPException):
error_code = 'document_already_finished'
description = "The document has been processed. Please refresh the page or go to the document details."
code = 400
class DocumentIndexingError(BaseHTTPException):
    error_code = 'document_indexing'
    description = "The document is being processed and cannot be edited."
    code = 400


class InvalidMetadataError(BaseHTTPException):
    error_code = 'invalid_metadata'
    description = "The metadata content is incorrect. Please check and verify."
    code = 400

View File

@@ -0,0 +1,201 @@
from flask_login import current_user
from flask_restful import reqparse, marshal
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
def post(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
parser = reqparse.RequestParser()
parser.add_argument('status', type=str,
action='append', default=[], location='args')
parser.add_argument('keyword', type=str, default=None, location='args')
args = parser.parse_args()
status_list = args['status']
keyword = args['keyword']
query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
)
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
if keyword:
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
total = query.count()
segments = query.order_by(DocumentSegment.position).all()
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form,
'total': total
}, 200
class DatasetSegmentApi(DatasetApiResource):
def delete(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check segment
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
SegmentService.delete_segment(segment, document, dataset)
return {'result': 'success'}, 200
def post(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args['segments'], document)
segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetSegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
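SegmentApi's POST expects a `segments` list in the JSON body, with each item validated by `SegmentService.segment_create_args_validate`. A minimal payload might be assembled as below; the per-segment keys (`content`, `answer`, `keywords`) are assumptions for illustration, since the validation schema is not shown in this diff:

```python
import json

# Sketch of a body for POST /datasets/<dataset_id>/documents/<document_id>/segments.
# 'segments' is the only argument SegmentApi's parser declares; the inner keys
# are hypothetical examples of per-segment fields.
payload = {
    'segments': [
        {'content': 'First chunk of text.', 'answer': None, 'keywords': []},
        {'content': 'Second chunk of text.', 'answer': None, 'keywords': []},
    ]
}
body = json.dumps(payload)
```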

View File

@@ -2,12 +2,14 @@
from datetime import datetime
from functools import wraps

from flask import request, current_app
from flask_login import user_logged_in
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized

from libs.login import _get_user
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, Account
from models.model import ApiToken, App
@@ -17,7 +19,7 @@ def validate_app_token(view=None):
    def decorated(*args, **kwargs):
        api_token = validate_and_get_api_token('app')

        app_model = db.session.query(App).filter(App.id == api_token.app_id).first()

        if not app_model:
            raise NotFound()
@@ -43,12 +45,24 @@ def validate_dataset_token(view=None):
    @wraps(view)
    def decorated(*args, **kwargs):
        api_token = validate_and_get_api_token('dataset')
        tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
            .filter(Tenant.id == api_token.tenant_id) \
            .filter(TenantAccountJoin.tenant_id == Tenant.id) \
            .filter(TenantAccountJoin.role == 'owner') \
            .one_or_none()
        if tenant_account_join:
            tenant, ta = tenant_account_join
            account = Account.query.filter_by(id=ta.account_id).first()
            # Log in the tenant owner account
            if account:
                account.current_tenant = tenant
                current_app.login_manager._update_request_context_with_user(account)
                user_logged_in.send(current_app._get_current_object(), user=_get_user())
            else:
                raise Unauthorized("Tenant owner account does not exist.")
        else:
            raise Unauthorized("Tenant does not exist.")
        return view(api_token.tenant_id, *args, **kwargs)

    return decorated

    if view:
@@ -64,14 +78,14 @@ def validate_and_get_api_token(scope=None):
    Validate and get API token.
    """
    auth_header = request.headers.get('Authorization')
    if auth_header is None or ' ' not in auth_header:
        raise Unauthorized("Authorization header must be provided and start with 'Bearer'")

    auth_scheme, auth_token = auth_header.split(None, 1)
    auth_scheme = auth_scheme.lower()

    if auth_scheme != 'bearer':
        raise Unauthorized("Authorization scheme must be 'Bearer'")

    api_token = db.session.query(ApiToken).filter(
        ApiToken.token == auth_token,
@@ -79,7 +93,7 @@ def validate_and_get_api_token(scope=None):
    ).first()

    if not api_token:
        raise Unauthorized("Access token is invalid")

    api_token.last_used_at = datetime.utcnow()
    db.session.commit()
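The header checks in `validate_and_get_api_token` can be isolated into a small pure helper; this sketch mirrors the split/lowercase/compare logic above (the error messages are copied from it, but the helper itself is illustrative and not part of the module):

```python
def parse_bearer_token(auth_header):
    """Return the token from an 'Authorization: Bearer <token>' header.

    Raises ValueError with the same messages the endpoint uses when the
    header is missing, malformed, or uses a non-Bearer scheme.
    """
    if auth_header is None or ' ' not in auth_header:
        raise ValueError("Authorization header must be provided and start with 'Bearer'")
    # split(None, 1): split on any whitespace, at most once, so the token may
    # itself contain spaces without being truncated further
    auth_scheme, auth_token = auth_header.split(None, 1)
    if auth_scheme.lower() != 'bearer':
        raise ValueError("Authorization scheme must be 'Bearer'")
    return auth_token


token = parse_bearer_token('Bearer abc123')  # → 'abc123'
```

Note that the scheme comparison is case-insensitive ('bearer', 'Bearer', 'BEARER' all pass), matching the `auth_scheme.lower()` call in the view code.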

View File

@@ -24,6 +24,7 @@ class AppParameterApi(WebApiResource):
        'suggested_questions': fields.Raw,
        'suggested_questions_after_answer': fields.Raw,
        'speech_to_text': fields.Raw,
        'retriever_resource': fields.Raw,
        'more_like_this': fields.Raw,
        'user_input_form': fields.Raw,
    }
@@ -38,6 +39,7 @@ class AppParameterApi(WebApiResource):
            'suggested_questions': app_model_config.suggested_questions_list,
            'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
            'speech_to_text': app_model_config.speech_to_text_dict,
            'retriever_resource': app_model_config.retriever_resource_dict,
            'more_like_this': app_model_config.more_like_this_dict,
            'user_input_form': app_model_config.user_input_form_list
        }

View File

@@ -29,8 +29,10 @@ class CompletionApi(WebApiResource):
        parser = reqparse.RequestParser()
        parser.add_argument('inputs', type=dict, required=True, location='json')
        parser.add_argument('query', type=str, location='json', default='')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'
@@ -88,6 +90,8 @@ class ChatApi(WebApiResource):
        parser.add_argument('query', type=str, required=True, location='json')
        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
        parser.add_argument('conversation_id', type=uuid_value, location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
        args = parser.parse_args()

        streaming = args['response_mode'] == 'streaming'

View File

@@ -6,26 +6,12 @@ from werkzeug.exceptions import NotFound
from controllers.web import api
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService


class ConversationListApi(WebApiResource):
@@ -73,7 +59,7 @@ class ConversationApi(WebApiResource):
class ConversationRenameApi(WebApiResource):
    @marshal_with(simple_conversation_fields)
    def post(self, app_model, end_user, c_id):
        if app_model.mode != 'chat':
            raise NotChatAppError()

View File

@@ -29,6 +29,25 @@ class MessageListApi(WebApiResource):
        'rating': fields.String
    }

    retriever_resource_fields = {
        'id': fields.String,
        'message_id': fields.String,
        'position': fields.Integer,
        'dataset_id': fields.String,
        'dataset_name': fields.String,
        'document_id': fields.String,
        'document_name': fields.String,
        'data_source_type': fields.String,
        'segment_id': fields.String,
        'score': fields.Float,
        'hit_count': fields.Integer,
        'word_count': fields.Integer,
        'segment_position': fields.Integer,
        'index_node_hash': fields.String,
        'content': fields.String,
        'created_at': TimestampField
    }

    message_fields = {
        'id': fields.String,
        'conversation_id': fields.String,
@@ -36,6 +55,7 @@ class MessageListApi(WebApiResource):
        'query': fields.String,
        'answer': fields.String,
        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
        'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
        'created_at': TimestampField
    }


@@ -1,3 +1,4 @@
+import json
 from typing import Tuple, List, Any, Union, Sequence, Optional, cast

 from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
@@ -52,14 +53,28 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
         elif len(self.tools) == 1:
             tool = next(iter(self.tools))
             tool = cast(DatasetRetrieverTool, tool)
-            rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
+            rst = tool.run(tool_input={'query': kwargs['input']})
+            # output = ''
+            # rst_json = json.loads(rst)
+            # for item in rst_json:
+            #     output += f'{item["content"]}\n'
             return AgentFinish(return_values={"output": rst}, log=rst)

         if intermediate_steps:
             _, observation = intermediate_steps[-1]
             return AgentFinish(return_values={"output": observation}, log=observation)

-        return super().plan(intermediate_steps, callbacks, **kwargs)
+        try:
+            agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
+            if isinstance(agent_decision, AgentAction):
+                tool_inputs = agent_decision.tool_input
+                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+                    tool_inputs['query'] = kwargs['input']
+                    agent_decision.tool_input = tool_inputs
+            return agent_decision
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception

     async def aplan(
         self,
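The `plan` override above rewrites whatever `query` the model generated in `tool_input` back to the raw user input before the dataset tool runs. A minimal stand-alone sketch of that rewrite, using a hypothetical `FakeAgentAction` dataclass in place of LangChain's `AgentAction`:

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class FakeAgentAction:  # stand-in for langchain.schema.AgentAction
    tool: str
    tool_input: Union[str, dict]
    log: str

def force_original_query(decision: FakeAgentAction, user_input: str) -> FakeAgentAction:
    """Overwrite a model-generated 'query' argument with the raw user input."""
    tool_inputs = decision.tool_input
    if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
        tool_inputs['query'] = user_input
        decision.tool_input = tool_inputs
    return decision

decision = FakeAgentAction(tool='dataset', tool_input={'query': 'paraphrased question'}, log='')
decision = force_original_query(decision, 'What is the refund policy?')
```

String-typed `tool_input` passes through untouched, which matches the `isinstance` guard in the diff.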


@@ -45,14 +45,18 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
         :return:
         """
         original_max_tokens = self.llm.max_tokens
-        self.llm.max_tokens = 15
+        self.llm.max_tokens = 40

         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()

-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=None
-        )
+        try:
+            predicted_message = self.llm.predict_messages(
+                messages, functions=self.functions, callbacks=None
+            )
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception

         function_call = predicted_message.additional_kwargs.get("function_call", {})
@@ -93,6 +97,13 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
             messages, functions=self.functions, callbacks=callbacks
         )
         agent_decision = _parse_ai_message(predicted_message)
+
+        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
+            tool_inputs = agent_decision.tool_input
+            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+                tool_inputs['query'] = kwargs['input']
+                agent_decision.tool_input = tool_inputs
+
         return agent_decision

     @classmethod


@@ -14,7 +14,7 @@ from core.model_providers.models.llm.base import BaseLLM
 class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel
+    summary_llm: BaseLanguageModel = None
     model_instance: BaseLLM

     class Config:
@@ -66,12 +66,12 @@ class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
         return new_messages

-    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
+    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

         Official documentation: https://github.com/openai/openai-cookbook/blob/
         main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        llm = cast(ChatOpenAI, llm)
+        llm = cast(ChatOpenAI, model_instance.client)
         model, encoding = llm._get_encoding_model()
         if model.startswith("gpt-3.5-turbo"):
             # every message follows <im_start>{role/name}\n{content}<im_end>\n
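The counting method above follows the OpenAI cookbook's ChatML layout: a fixed per-message overhead plus the encoded content, plus a priming cost for the reply. A sketch of that arithmetic with a stand-in one-token-per-word tokenizer (the real code encodes with tiktoken's model-specific encoding, and the constants below are the older cookbook figures for gpt-3.5-turbo):

```python
# Stand-in tokenizer: one token per whitespace-separated word.
# The real implementation uses the encoding from llm._get_encoding_model().
def encode(text: str) -> list:
    return text.split()

def num_tokens_from_messages(messages: list) -> int:
    tokens_per_message = 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for value in message.values():
            num_tokens += len(encode(value))
    num_tokens += 2  # every reply is primed with <im_start>assistant
    return num_tokens

total = num_tokens_from_messages([{"role": "user", "content": "hello there"}])
```

With the stand-in tokenizer this evaluates to 4 (message overhead) + 1 (`"user"`) + 2 (`"hello there"`) + 2 (reply priming).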


@@ -50,9 +50,13 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope
         prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
         messages = prompt.to_messages()

-        predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=None
-        )
+        try:
+            predicted_message = self.llm.predict_messages(
+                messages, functions=self.functions, callbacks=None
+            )
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception

         function_call = predicted_message.additional_kwargs.get("function_call", {})


@@ -10,7 +10,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
 class StructuredChatOutputParser(LCStructuredChatOutputParser):
     def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
         try:
-            action_match = re.search(r"```(.*?)\n(.*?)```?", text, re.DOTALL)
+            action_match = re.search(r"```(\w*)\n?({.*?)```", text, re.DOTALL)
             if action_match is not None:
                 response = json.loads(action_match.group(2).strip(), strict=False)
                 if isinstance(response, list):
@@ -26,4 +26,4 @@ class StructuredChatOutputParser(LCStructuredChatOutputParser):
             else:
                 return AgentFinish({"output": text}, text)
         except Exception as e:
-            raise OutputParserException(f"Could not parse LLM output: {text}") from e
+            raise OutputParserException(f"Could not parse LLM output: {text}")
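The tightened pattern anchors the second capture group on a `{` after the opening fence, so only a fenced JSON object is extracted, and the language tag is restricted to word characters instead of a greedy `.*?`. A quick check of the new pattern against a typical structured-chat reply (the sample text is illustrative, not from the source):

```python
import json
import re

PATTERN = re.compile(r"```(\w*)\n?({.*?)```", re.DOTALL)

text = ('Thought: use the dataset tool\n'
        '```json\n'
        '{"action": "dataset", "action_input": {"query": "refunds"}}\n'
        '```')
match = PATTERN.search(text)
response = json.loads(match.group(2).strip(), strict=False)
```

Group 1 holds the fence's language tag and group 2 the JSON body, mirroring how `parse` feeds `group(2)` to `json.loads`.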


@@ -90,14 +90,25 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         elif len(self.dataset_tools) == 1:
             tool = next(iter(self.dataset_tools))
             tool = cast(DatasetRetrieverTool, tool)
-            rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
+            rst = tool.run(tool_input={'query': kwargs['input']})
             return AgentFinish(return_values={"output": rst}, log=rst)

         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
-        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)

         try:
-            return self.output_parser.parse(full_output)
+            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception
+
+        try:
+            agent_decision = self.output_parser.parse(full_output)
+            if isinstance(agent_decision, AgentAction):
+                tool_inputs = agent_decision.tool_input
+                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+                    tool_inputs['query'] = kwargs['input']
+                    agent_decision.tool_input = tool_inputs
+            return agent_decision
         except OutputParserException:
             return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                          "I don't know how to respond to that."}, "")


@@ -52,7 +52,7 @@ Action:
 class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
-    summary_llm: BaseLanguageModel
+    summary_llm: BaseLanguageModel = None
     model_instance: BaseLLM

     class Config:
@@ -89,8 +89,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             Action specifying what tool to use.
         """
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
         prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
         messages = []
         if prompts:
             messages = prompts[0].to_messages()
@@ -99,16 +99,26 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         if rest_tokens < 0:
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

-        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
+        try:
+            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
+        except Exception as e:
+            new_exception = self.model_instance.handle_exceptions(e)
+            raise new_exception

         try:
-            return self.output_parser.parse(full_output)
+            agent_decision = self.output_parser.parse(full_output)
+            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
+                tool_inputs = agent_decision.tool_input
+                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+                    tool_inputs['query'] = kwargs['input']
+                    agent_decision.tool_input = tool_inputs
+            return agent_decision
         except OutputParserException:
             return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                          "I don't know how to respond to that."}, "")

     def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
-        if len(intermediate_steps) >= 2:
+        if len(intermediate_steps) >= 2 and self.summary_llm:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
                                        for _, observation in should_summary_intermediate_steps]


@@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor

+from core.helper import moderation
+from core.model_providers.error import LLMError
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -32,7 +34,7 @@ class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
     model_instance: BaseLLM
     tools: list[BaseTool]
-    summary_model_instance: BaseLLM
+    summary_model_instance: BaseLLM = None
     memory: Optional[BaseChatMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
@@ -65,7 +67,8 @@ class AgentExecutor:
                 llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_llm=self.configuration.summary_model_instance.client,
+                summary_llm=self.configuration.summary_model_instance.client
+                if self.configuration.summary_model_instance else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
@@ -74,7 +77,8 @@ class AgentExecutor:
                 llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_model_instance.client,
+                summary_llm=self.configuration.summary_model_instance.client
+                if self.configuration.summary_model_instance else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
@@ -83,7 +87,8 @@ class AgentExecutor:
                 llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_model_instance.client,
+                summary_llm=self.configuration.summary_model_instance.client
+                if self.configuration.summary_model_instance else None,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.ROUTER:
@@ -113,6 +118,18 @@ class AgentExecutor:
         return self.agent.should_use_agent(query)

     def run(self, query: str) -> AgentExecuteResult:
+        moderation_result = moderation.check_moderation(
+            self.configuration.model_instance.model_provider,
+            query
+        )
+
+        if not moderation_result:
+            return AgentExecuteResult(
+                output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
+                strategy=self.configuration.strategy,
+                configuration=self.configuration
+            )
+
         agent_executor = LCAgentExecutor.from_agent_and_tools(
             agent=self.agent,
             tools=self.configuration.tools,
@@ -125,7 +142,9 @@ class AgentExecutor:
         try:
             output = agent_executor.run(query)
-        except Exception:
+        except LLMError as ex:
+            raise ex
+        except Exception as ex:
             logging.exception("agent_executor run failed")
             output = None


@@ -6,10 +6,11 @@ from typing import Any, Dict, List, Union, Optional
 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
+from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage

 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
@@ -17,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True

-    def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
-        self.model_instant = model_instant
+        self.model_instance = model_instance
         self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
@@ -45,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         """Whether to ignore chain callbacks."""
         return True

+    def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            **kwargs: Any
+    ) -> Any:
+        if not self._current_loop:
+            # Agent start with a LLM query
+            self._current_loop = AgentLoop(
+                position=len(self._agent_loops) + 1,
+                prompt="\n".join([message.content for message in messages[0]]),
+                status='llm_started',
+                started_at=time.perf_counter()
+            )
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
@@ -68,6 +84,10 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.status = 'llm_end'
             if response.llm_output:
                 self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+            else:
+                self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
+                    [PromptMessage(content=self._current_loop.prompt)]
+                )
             completion_generation = response.generations[0][0]
             if isinstance(completion_generation, ChatGeneration):
                 completion_message = completion_generation.message
@@ -81,11 +101,15 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):

             if response.llm_output:
                 self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+            else:
+                self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
+                    [PromptMessage(content=self._current_loop.completion)]
+                )

     def on_llm_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
-        logging.exception(error)
+        logging.debug("Agent on_llm_error: %s", error)
         self._agent_loops = []
         self._current_loop = None
         self._message_agent_thought = None
@@ -153,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
                 self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at

                 self.conversation_message_task.on_agent_end(
-                    self._message_agent_thought, self.model_instant, self._current_loop
+                    self._message_agent_thought, self.model_instance, self._current_loop
                 )

                 self._agent_loops.append(self._current_loop)
@@ -164,7 +188,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         """Do nothing."""
-        logging.exception(error)
+        logging.debug("Agent on_tool_error: %s", error)
         self._agent_loops = []
         self._current_loop = None
         self._message_agent_thought = None
@@ -184,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             )

             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_instant, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )

             self._agent_loops.append(self._current_loop)


@@ -1,5 +1,6 @@
 import json
 import logging
+from json import JSONDecodeError

 from typing import Any, Dict, List, Union, Optional
@@ -44,10 +45,15 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
         input_str: str,
         **kwargs: Any,
     ) -> None:
-        # tool_name = serialized.get('name')
-        input_dict = json.loads(input_str.replace("'", "\""))
-        dataset_id = input_dict.get('dataset_id')
-        query = input_dict.get('query')
+        tool_name: str = serialized.get('name')
+        dataset_id = tool_name.removeprefix('dataset-')
+
+        try:
+            input_dict = json.loads(input_str.replace("'", "\""))
+            query = input_dict.get('query')
+        except JSONDecodeError:
+            query = input_str
+
         self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query))

     def on_tool_end(
@@ -58,14 +64,11 @@ class DatasetToolCallbackHandler(BaseCallbackHandler):
         llm_prefix: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
-        # kwargs={'name': 'Search'}
-        # llm_prefix='Thought:'
-        # observation_prefix='Observation: '
-        # output='53 years'
         pass

     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         """Do nothing."""
-        logging.exception(error)
+        logging.debug("Dataset tool on_llm_error: %s", error)
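The handler above now recovers the dataset id from the tool's name (which the diff implies is formatted `dataset-<id>`) instead of the tool input, and falls back to treating non-JSON input as the query itself. A self-contained sketch of that parsing logic:

```python
import json
from json import JSONDecodeError

def parse_tool_start(tool_name: str, input_str: str):
    """Recover (dataset_id, query) from a dataset tool invocation."""
    # Tool names are assumed to be 'dataset-<id>'; removeprefix needs Python 3.9+.
    dataset_id = tool_name.removeprefix('dataset-')
    try:
        # tool input may arrive as a repr-style dict with single quotes
        input_dict = json.loads(input_str.replace("'", "\""))
        query = input_dict.get('query')
    except JSONDecodeError:
        # plain-string input (e.g. from a single-tool router) is used as-is
        query = input_str
    return dataset_id, query
```

This is why the old `dataset_id = input_dict.get('dataset_id')` lookup could be dropped: the id travels in the tool name, so a bare-string query no longer loses it.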


@@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
     prompt_tokens: int = 0
     completion: str = ''
     completion_tokens: int = 0
-    latency: float = 0.0


@@ -2,6 +2,7 @@ from typing import List

 from langchain.schema import Document

+from core.conversation_message_task import ConversationMessageTask
 from extensions.ext_database import db
 from models.dataset import DocumentSegment
@@ -9,8 +10,9 @@ from models.dataset import DocumentSegment
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""

-    def __init__(self, dataset_id: str) -> None:
+    def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None:
         self.dataset_id = dataset_id
+        self.conversation_message_task = conversation_message_task

     def on_tool_end(self, documents: List[Document]) -> None:
         """Handle tool end."""
@@ -27,3 +29,7 @@ class DatasetIndexToolCallbackHandler:
             )
             db.session.commit()
+
+    def return_retriever_resource_info(self, resource: List):
+        """Handle return_retriever_resource_info."""
+        self.conversation_message_task.on_dataset_query_finish(resource)


@@ -1,5 +1,4 @@
 import logging
-import time
 from typing import Any, Dict, List, Union

 from langchain.callbacks.base import BaseCallbackHandler
@@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         messages: List[List[BaseMessage]],
         **kwargs: Any
     ) -> Any:
-        self.start_at = time.perf_counter()
         real_prompts = []
         for message in messages[0]:
             if message.type == 'human':
@@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
-        self.start_at = time.perf_counter()
-
         self.llm_message.prompt = [{
             "role": 'user',
             "text": prompts[0]
@@ -63,14 +59,22 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        end_at = time.perf_counter()
-        self.llm_message.latency = end_at - self.start_at
-
         if not self.conversation_message_task.streaming:
             self.conversation_message_task.append_message_text(response.generations[0][0].text)

         self.llm_message.completion = response.generations[0][0].text
-        self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
+
+        if response.llm_output and 'token_usage' in response.llm_output:
+            if 'prompt_tokens' in response.llm_output['token_usage']:
+                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
+
+            if 'completion_tokens' in response.llm_output['token_usage']:
+                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
+            else:
+                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                    [PromptMessage(content=self.llm_message.completion)])
+        else:
+            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                [PromptMessage(content=self.llm_message.completion)])

         self.conversation_message_task.save_message(self.llm_message)
@@ -89,8 +93,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
         """Do nothing."""
         if isinstance(error, ConversationTaskStoppedException):
             if self.conversation_message_task.streaming:
-                end_at = time.perf_counter()
-                self.llm_message.latency = end_at - self.start_at
-
                 self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self.llm_message.completion)]
                 )
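`on_llm_end` above now prefers the provider-reported `token_usage` from `llm_output` and only counts locally when the provider omits it. A sketch of that precedence rule, using a hypothetical word-count stand-in for `model_instance.get_num_tokens(...)`:

```python
from typing import Optional

def resolve_completion_tokens(llm_output: Optional[dict], completion: str) -> int:
    """Prefer provider-reported usage; fall back to local counting."""
    if llm_output and 'token_usage' in llm_output:
        usage = llm_output['token_usage']
        if 'completion_tokens' in usage:
            return usage['completion_tokens']
    # fallback: stand-in for model_instance.get_num_tokens([PromptMessage(...)])
    return len(completion.split())

reported = resolve_completion_tokens({'token_usage': {'completion_tokens': 42}}, 'ignored')
counted = resolve_completion_tokens(None, 'four words right here')
```

Trusting the provider's figure when available keeps billing-relevant counts exact, while the fallback keeps stats usable for providers that stream without usage metadata.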


@@ -72,5 +72,5 @@ class MainChainGatherCallbackHandler(BaseCallbackHandler):
     def on_chain_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
-        logging.exception(error)
+        logging.debug("Dataset tool on_chain_error: %s", error)
         self.clear_chain_results()


@ -1,15 +1,33 @@
import enum
import logging
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from pydantic import BaseModel
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation import openai_moderation
class SensitiveWordAvoidanceRule(BaseModel):
class Type(enum.Enum):
MODERATION = "moderation"
KEYWORDS = "keywords"
type: Type
canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
extra_params: dict = {}
 class SensitiveWordAvoidanceChain(Chain):
     input_key: str = "input"  #: :meta private:
     output_key: str = "output"  #: :meta private:

-    sensitive_words: List[str] = []
-    canned_response: str = None
+    model_instance: BaseLLM
+    sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule

     @property
     def _chain_type(self) -> str:
@@ -31,11 +49,24 @@ class SensitiveWordAvoidanceChain(Chain):
         """
         return [self.output_key]

-    def _check_sensitive_word(self, text: str) -> str:
-        for word in self.sensitive_words:
-            if word in text:
-                return self.canned_response
-        return text
+    def _check_sensitive_word(self, text: str) -> bool:
+        for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
+            if word in text:
+                return False
+        return True
+
+    def _check_moderation(self, text: str) -> bool:
+        moderation_model_instance = ModelFactory.get_moderation_model(
+            tenant_id=self.model_instance.model_provider.provider.tenant_id,
+            model_provider_name='openai',
+            model_name=openai_moderation.DEFAULT_MODEL
+        )
+
+        try:
+            return moderation_model_instance.run(text=text)
+        except Exception as ex:
+            logging.exception(ex)
+            raise LLMBadRequestError('Rate limit exceeded, please try again later.')

     def _call(
         self,
@@ -43,5 +74,19 @@ class SensitiveWordAvoidanceChain(Chain):
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Dict[str, Any]:
         text = inputs[self.input_key]
-        output = self._check_sensitive_word(text)
-        return {self.output_key: output}
+
+        if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
+            result = self._check_sensitive_word(text)
+        else:
+            result = self._check_moderation(text)
+
+        if not result:
+            raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
+
+        return {self.output_key: text}
+
+
+class SensitiveWordAvoidanceError(Exception):
+    def __init__(self, message):
+        super().__init__(message)
+        self.message = message
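The reworked `_check_sensitive_word` now returns a pass/fail flag instead of substituting text, and the caller raises `SensitiveWordAvoidanceError` carrying the canned response. A minimal standalone sketch of that keyword check (function names here are hypothetical illustrations, not Dify's API):

```python
from typing import List


def check_sensitive_words(text: str, sensitive_words: List[str]) -> bool:
    """Return True when the text passes; False when any configured keyword appears."""
    return not any(word in text for word in sensitive_words)


def guard(text: str, sensitive_words: List[str], canned_response: str) -> str:
    # Mirrors the chain's behavior: on failure the canned response replaces
    # the model call entirely (hypothetical wrapper, for illustration only).
    if not check_sensitive_words(text, sensitive_words):
        return canned_response
    return text
```

The same pass/fail contract lets the moderation-based check slot in behind the identical interface.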

@@ -1,31 +1,32 @@
-import json
 import logging
-import re
-from typing import Optional, List, Union, Tuple
+from typing import Optional, List, Union

-from langchain.schema import BaseMessage
 from requests.exceptions import ChunkedEncodingError

 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
+from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
-from models.dataset import DocumentSegment, Dataset, Document
 from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser


 class Completion:
     @classmethod
     def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
-                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False):
+                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
+                 is_override: bool = False, retriever_from: str = 'dev'):
         """
         errors: ProviderTokenNotInitError
         """
@@ -76,29 +77,55 @@ class Completion:
             app_model_config=app_model_config
         )

-        # parse sensitive_word_avoidance_chain
-        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
-        if sensitive_word_avoidance_chain:
-            query = sensitive_word_avoidance_chain.run(query)
-
-        # get agent executor
-        agent_executor = orchestrator_rule_parser.to_agent_executor(
-            conversation_message_task=conversation_message_task,
-            memory=memory,
-            rest_tokens=rest_tokens_for_context_and_memory,
-            chain_callback=chain_callback
-        )
-
-        # run agent executor
-        agent_execute_result = None
-        if agent_executor:
-            should_use_agent = agent_executor.should_use_agent(query)
-            if should_use_agent:
-                agent_execute_result = agent_executor.run(query)
-
-        # run the final llm
         try:
+            # parse sensitive_word_avoidance_chain
+            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+            sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
+                final_model_instance, [chain_callback])
+            if sensitive_word_avoidance_chain:
+                try:
+                    query = sensitive_word_avoidance_chain.run(query)
+                except SensitiveWordAvoidanceError as ex:
+                    cls.run_final_llm(
+                        model_instance=final_model_instance,
+                        mode=app.mode,
+                        app_model_config=app_model_config,
+                        query=query,
+                        inputs=inputs,
+                        agent_execute_result=None,
+                        conversation_message_task=conversation_message_task,
+                        memory=memory,
+                        fake_response=ex.message
+                    )
+                    return
+
+            # get agent executor
+            agent_executor = orchestrator_rule_parser.to_agent_executor(
+                conversation_message_task=conversation_message_task,
+                memory=memory,
+                rest_tokens=rest_tokens_for_context_and_memory,
+                chain_callback=chain_callback,
+                retriever_from=retriever_from
+            )
+
+            query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)
+
+            # run agent executor
+            agent_execute_result = None
+            if query_for_agent and agent_executor:
+                should_use_agent = agent_executor.should_use_agent(query_for_agent)
+                if should_use_agent:
+                    agent_execute_result = agent_executor.run(query_for_agent)
+
+            # When no extra pre prompt is specified,
+            # the output of the agent can be used directly as the main output content without calling LLM again
+            fake_response = None
+            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
+                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
+                                                              PlanningStrategy.REACT_ROUTER]:
+                fake_response = agent_execute_result.output
+
+            # run the final llm
             cls.run_final_llm(
                 model_instance=final_model_instance,
                 mode=app.mode,
@@ -107,7 +134,8 @@ class Completion:
                 inputs=inputs,
                 agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
-                memory=memory
+                memory=memory,
+                fake_response=fake_response
             )
         except ConversationTaskStoppedException:
             return
@@ -116,27 +144,28 @@ class Completion:
             logging.warning(f'ChunkedEncodingError: {e}')
             conversation_message_task.end()
             return
+    @classmethod
+    def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
+        if app.mode != 'completion':
+            return query
+
+        return inputs.get(app_model_config.dataset_query_variable, "")
+
     @classmethod
-    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
+    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
+                      inputs: dict,
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
-                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
-        # When no extra pre prompt is specified,
-        # the output of the agent can be used directly as the main output content without calling LLM again
-        fake_response = None
-        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
-                and agent_execute_result.strategy != PlanningStrategy.ROUTER:
-            fake_response = agent_execute_result.output
-
+                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
+                      fake_response: Optional[str]):
         # get llm prompt
-        prompt_messages, stop_words = cls.get_main_llm_prompt(
+        prompt_messages, stop_words = model_instance.get_prompt(
             mode=mode,
-            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
-            query=query,
             inputs=inputs,
-            agent_execute_result=agent_execute_result,
+            query=query,
+            context=agent_execute_result.output if agent_execute_result else None,
             memory=memory
         )
@@ -151,116 +180,8 @@ class Completion:
             callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
             fake_response=fake_response
         )

         return response
-    @classmethod
-    def get_main_llm_prompt(cls, mode: str, model: dict,
-                            pre_prompt: str, query: str, inputs: dict,
-                            agent_execute_result: Optional[AgentExecuteResult],
-                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
-            Tuple[List[PromptMessage], Optional[List[str]]]:
-        if mode == 'completion':
-            prompt_template = JinjaPromptTemplate.from_template(
-                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
-
-<context>
-{{context}}
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-""" if agent_execute_result else "")
-                         + (pre_prompt + "\n" if pre_prompt else "")
-                         + "{{query}}\n"
-            )
-
-            if agent_execute_result:
-                inputs['context'] = agent_execute_result.output
-
-            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
-            prompt_content = prompt_template.format(
-                query=query,
-                **prompt_inputs
-            )
-
-            return [PromptMessage(content=prompt_content)], None
-        else:
-            messages: List[BaseMessage] = []
-
-            human_inputs = {
-                "query": query
-            }
-
-            human_message_prompt = ""
-
-            if pre_prompt:
-                pre_prompt_inputs = {k: inputs[k] for k in
-                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
-                                     if k in inputs}
-
-                if pre_prompt_inputs:
-                    human_inputs.update(pre_prompt_inputs)
-
-            if agent_execute_result:
-                human_inputs['context'] = agent_execute_result.output
-                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
-
-<context>
-{{context}}
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-"""
-
-            if pre_prompt:
-                human_message_prompt += pre_prompt
-
-            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
-
-            if memory:
-                # append chat histories
-                tmp_human_message = PromptBuilder.to_human_message(
-                    prompt_content=human_message_prompt + query_prompt,
-                    inputs=human_inputs
-                )
-
-                if memory.model_instance.model_rules.max_tokens.max:
-                    curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
-                    max_tokens = model.get("completion_params").get('max_tokens')
-                    rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
-                    rest_tokens = max(rest_tokens, 0)
-                else:
-                    rest_tokens = 2000
-
-                histories = cls.get_history_messages_from_memory(memory, rest_tokens)
-                human_message_prompt += "\n\n" if human_message_prompt else ""
-                human_message_prompt += "Here is the chat histories between human and assistant, " \
-                                        "inside <histories></histories> XML tags.\n\n<histories>\n"
-                human_message_prompt += histories + "\n</histories>"
-
-            human_message_prompt += query_prompt
-
-            # construct main prompt
-            human_message = PromptBuilder.to_human_message(
-                prompt_content=human_message_prompt,
-                inputs=human_inputs
-            )
-
-            messages.append(human_message)
-
-            for message in messages:
-                message.content = re.sub(r'<\|.*?\|>', '', message.content)
-
-            return to_prompt_messages(messages), ['\nHuman:', '</histories>']
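The removed prompt builder budgeted chat history as whatever tokens remain after the current prompt and the reserved completion `max_tokens`, falling back to 2000 when the model declares no hard limit. The same arithmetic in isolation (the function name is ours, not part of the codebase):

```python
from typing import Optional


def rest_tokens_for_history(model_max_tokens: Optional[int],
                            completion_max_tokens: int,
                            current_prompt_tokens: int) -> int:
    # History budget = model context limit - reserved completion tokens
    # - tokens already consumed by the prompt itself, floored at zero.
    if model_max_tokens:
        return max(model_max_tokens - completion_max_tokens - current_prompt_tokens, 0)
    return 2000  # fallback when the model declares no hard limit
```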
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> str:
@@ -307,13 +228,12 @@ And answer according to the language of the user's question.
         max_tokens = 0

         # get prompt without memory and context
-        prompt_messages, _ = cls.get_main_llm_prompt(
+        prompt_messages, _ = model_instance.get_prompt(
             mode=mode,
-            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
-            query=query,
             inputs=inputs,
-            agent_execute_result=None,
+            query=query,
+            context=None,
             memory=None
         )
@@ -358,13 +278,12 @@ And answer according to the language of the user's question.
         )

         # get llm prompt
-        old_prompt_messages, _ = cls.get_main_llm_prompt(
-            mode="completion",
-            model=app_model_config.model_dict,
+        old_prompt_messages, _ = final_model_instance.get_prompt(
+            mode='completion',
             pre_prompt=pre_prompt,
-            query=message.query,
             inputs=message.inputs,
-            agent_execute_result=None,
+            query=message.query,
+            context=None,
             memory=None
         )

@@ -1,6 +1,6 @@
-import decimal
 import json
-from typing import Optional, Union
+import time
+from typing import Optional, Union, List

 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
@@ -15,13 +15,16 @@ from events.message_event import message_was_created
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DatasetQuery
-from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
+from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
+    MessageChain, DatasetRetrieverResource


 class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
                  inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
                  conversation: Optional[Conversation] = None, is_override: bool = False):
+        self.start_at = time.perf_counter()
+
         self.task_id = task_id
         self.app = app
@@ -41,6 +44,8 @@ class ConversationMessageTask:
         self.message = None

+        self.retriever_resource = None
+
         self.model_dict = self.app_model_config.model_dict
         self.provider_name = self.model_dict.get('provider')
         self.model_name = self.model_dict.get('name')
@@ -58,19 +63,10 @@ class ConversationMessageTask:
         )

     def init(self):
         override_model_configs = None
         if self.is_override:
-            override_model_configs = {
-                "model": self.app_model_config.model_dict,
-                "pre_prompt": self.app_model_config.pre_prompt,
-                "agent_mode": self.app_model_config.agent_mode_dict,
-                "opening_statement": self.app_model_config.opening_statement,
-                "suggested_questions": self.app_model_config.suggested_questions_list,
-                "suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
-                "more_like_this": self.app_model_config.more_like_this_dict,
-                "sensitive_word_avoidance": self.app_model_config.sensitive_word_avoidance_dict,
-                "user_input_form": self.app_model_config.user_input_form_list,
-            }
+            override_model_configs = self.app_model_config.to_dict()

         introduction = ''
         system_instruction = ''
@@ -98,7 +94,7 @@ class ConversationMessageTask:
         if not self.conversation:
             self.is_new_conversation = True
             self.conversation = Conversation(
-                app_id=self.app_model_config.app_id,
+                app_id=self.app.id,
                 app_model_config_id=self.app_model_config.id,
                 model_provider=self.provider_name,
                 model_id=self.model_name,
@@ -116,10 +112,10 @@ class ConversationMessageTask:
             )

             db.session.add(self.conversation)
-            db.session.flush()
+            db.session.commit()

         self.message = Message(
-            app_id=self.app_model_config.app_id,
+            app_id=self.app.id,
             model_provider=self.provider_name,
             model_id=self.model_name,
             override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
@@ -129,9 +125,11 @@ class ConversationMessageTask:
             message="",
             message_tokens=0,
             message_unit_price=0,
+            message_price_unit=0,
             answer="",
             answer_tokens=0,
             answer_unit_price=0,
+            answer_price_unit=0,
             provider_response_latency=0,
             total_price=0,
             currency=self.model_instance.get_currency(),
@@ -142,26 +140,35 @@ class ConversationMessageTask:
         )

         db.session.add(self.message)
-        db.session.flush()
+        db.session.commit()

     def append_message_text(self, text: str):
-        self._pub_handler.pub_text(text)
+        if text is not None:
+            self._pub_handler.pub_text(text)

     def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
         message_tokens = llm_message.prompt_tokens
         answer_tokens = llm_message.completion_tokens
-        message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
-        answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
-        total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
+
+        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
+        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
+
+        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
+        total_price = message_total_price + answer_total_price

         self.message.message = llm_message.prompt
         self.message.message_tokens = message_tokens
         self.message.message_unit_price = message_unit_price
-        self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
+        self.message.message_price_unit = message_price_unit
+        self.message.answer = PromptBuilder.process_template(
+            llm_message.completion.strip()) if llm_message.completion else ''
         self.message.answer_tokens = answer_tokens
         self.message.answer_unit_price = answer_unit_price
-        self.message.provider_response_latency = llm_message.latency
+        self.message.answer_price_unit = answer_price_unit
+        self.message.provider_response_latency = time.perf_counter() - self.start_at
         self.message.total_price = total_price

         db.session.commit()
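`save_message` now sums per-side totals from `calc_tokens_price` rather than calling the removed `calc_total_price` helper. The underlying per-1k-token arithmetic, as that removed helper did it with `Decimal` rounding (standalone sketch; the function name is ours):

```python
from decimal import Decimal, ROUND_HALF_UP


def tokens_price(tokens: int, unit_price_per_1k: Decimal) -> Decimal:
    # Per-1k pricing with the same quantization as the removed calc_total_price:
    # token count scaled to thousands at 3 decimal places, result at 7.
    per_1k = (Decimal(tokens) / 1000).quantize(Decimal('0.001'), rounding=ROUND_HALF_UP)
    return (per_1k * unit_price_per_1k).quantize(Decimal('0.0000001'), rounding=ROUND_HALF_UP)


# message side + answer side, summed as save_message now does
total = tokens_price(1500, Decimal('0.002')) + tokens_price(500, Decimal('0.004'))
```

Using `Decimal` rather than floats keeps the stored prices exact at a fixed precision, which matters when many small per-message charges are aggregated.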
@@ -184,12 +191,13 @@ class ConversationMessageTask:
         )

         db.session.add(message_chain)
-        db.session.flush()
+        db.session.commit()

         return message_chain

     def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
         message_chain.output = json.dumps(chain_result.completion)
+        db.session.commit()

         self._pub_handler.pub_chain(message_chain)
@@ -202,44 +210,47 @@ class ConversationMessageTask:
             tool=agent_loop.tool_name,
             tool_input=agent_loop.tool_input,
             message=agent_loop.prompt,
+            message_price_unit=0,
             answer=agent_loop.completion,
+            answer_price_unit=0,
             created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
             created_by=self.user.id
         )

         db.session.add(message_agent_thought)
-        db.session.flush()
+        db.session.commit()

         self._pub_handler.pub_agent_thought(message_agent_thought)

         return message_agent_thought

-    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
+    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
                      agent_loop: AgentLoop):
-        agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
-        agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
+        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
+        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
+        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
+        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)

         loop_message_tokens = agent_loop.prompt_tokens
         loop_answer_tokens = agent_loop.completion_tokens

-        loop_total_price = self.calc_total_price(
-            loop_message_tokens,
-            agent_message_unit_price,
-            loop_answer_tokens,
-            agent_answer_unit_price
-        )
+        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
+        loop_total_price = loop_message_total_price + loop_answer_total_price

         message_agent_thought.observation = agent_loop.tool_output
         message_agent_thought.tool_process_data = ''  # currently not support
         message_agent_thought.message_token = loop_message_tokens
         message_agent_thought.message_unit_price = agent_message_unit_price
+        message_agent_thought.message_price_unit = agent_message_price_unit
         message_agent_thought.answer_token = loop_answer_tokens
         message_agent_thought.answer_unit_price = agent_answer_unit_price
+        message_agent_thought.answer_price_unit = agent_answer_price_unit
         message_agent_thought.latency = agent_loop.latency
         message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
         message_agent_thought.total_price = loop_total_price
-        message_agent_thought.currency = agent_model_instant.get_currency()
+        message_agent_thought.currency = agent_model_instance.get_currency()

-        db.session.flush()
+        db.session.commit()

     def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
         dataset_query = DatasetQuery(
@@ -252,17 +263,38 @@ class ConversationMessageTask:
         )

         db.session.add(dataset_query)
+        db.session.commit()

-    def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
-        message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                  rounding=decimal.ROUND_HALF_UP)
-        answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
-                                                                                rounding=decimal.ROUND_HALF_UP)
-
-        total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
-        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
+    def on_dataset_query_finish(self, resource: List):
+        if resource and len(resource) > 0:
+            for item in resource:
+                dataset_retriever_resource = DatasetRetrieverResource(
+                    message_id=self.message.id,
+                    position=item.get('position'),
+                    dataset_id=item.get('dataset_id'),
+                    dataset_name=item.get('dataset_name'),
+                    document_id=item.get('document_id'),
+                    document_name=item.get('document_name'),
+                    data_source_type=item.get('data_source_type'),
+                    segment_id=item.get('segment_id'),
+                    score=item.get('score') if 'score' in item else None,
+                    hit_count=item.get('hit_count') if 'hit_count' in item else None,
+                    word_count=item.get('word_count') if 'word_count' in item else None,
+                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
+                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
+                    content=item.get('content'),
+                    retriever_from=item.get('retriever_from'),
+                    created_by=self.user.id
+                )
+                db.session.add(dataset_retriever_resource)
+                db.session.commit()
+            self.retriever_resource = resource
+
+    def message_end(self):
+        self._pub_handler.pub_message_end(self.retriever_resource)

     def end(self):
+        self._pub_handler.pub_message_end(self.retriever_resource)
         self._pub_handler.pub_end()
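The new `message_end` event is a JSON payload published over the Redis channel, with `retriever_resources` attached only when dataset citations exist. A serialization-only sketch (payload fields follow `pub_message_end` in the hunk below; `mode` and `conversation_id` are omitted, and the Redis publish itself is left out so the sketch has no external dependency):

```python
import json
from typing import List, Optional


def build_message_end_event(task_id: str, message_id: str,
                            retriever_resource: Optional[List[dict]]) -> str:
    # The string that redis_client.publish(channel, ...) would send.
    content = {
        'event': 'message_end',
        'data': {'task_id': task_id, 'message_id': message_id},
    }
    if retriever_resource:
        content['data']['retriever_resources'] = retriever_resource
    return json.dumps(content)
```

Omitting the key entirely when there are no citations keeps older stream consumers, which only switch on `event`, working unchanged.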
@@ -356,6 +388,23 @@ class PubHandler:
             self.pub_end()
             raise ConversationTaskStoppedException()

+    def pub_message_end(self, retriever_resource: List):
+        content = {
+            'event': 'message_end',
+            'data': {
+                'task_id': self._task_id,
+                'message_id': self._message.id,
+                'mode': self._conversation.mode,
+                'conversation_id': self._conversation.id
+            }
+        }
+
+        if retriever_resource:
+            content['data']['retriever_resources'] = retriever_resource
+
+        redis_client.publish(self._channel, json.dumps(content))
+
+        if self._is_stopped():
+            self.pub_end()
+            raise ConversationTaskStoppedException()

     def pub_end(self):
         content = {
