Compare commits

..

206 Commits

Author SHA1 Message Date
909cc5a49f redis iam check 2024-06-14 03:07:17 +08:00
cdc08a434f feat: support Chroma vector store (#5015) 2024-06-13 18:02:18 +08:00
3f18369ad2 Fix: google storage init with sa and download (#5054) 2024-06-13 17:36:34 +08:00
db976a1f74 Upgrade boto3 library to support EKS Pod Identity. (#5064) 2024-06-13 17:36:14 +08:00
e61f5d029a chore(docs): fix minor small typos (#5124) 2024-06-13 17:36:01 +08:00
eaca892c4e fix: front end error when same tool is called twice at once (#5068) 2024-06-13 17:16:59 +08:00
015c26d303 fix: style misalignment and inconsistency (#5149) 2024-06-13 16:32:42 +08:00
8210637bc5 feat: support jina-clip-v1 embedding model (#5146) 2024-06-13 16:31:18 +08:00
790543131a chore:add some new api version for azure openai (#5142) 2024-06-13 16:30:47 +08:00
a40f68cf94 chore: update qdrant_vector.py (#5128) 2024-06-13 15:35:14 +08:00
adc948e87c fix(api/core/model_runtime/model_providers/baichuan,localai): Parse ToolPromptMessage. #4943 (#5138)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-06-13 13:08:30 +08:00
742b08e1d5 chore: update question classifier prompt (#5137)
Signed-off-by: 0xff-dev <stevenshuang521@gmail.com>
2024-06-13 13:04:51 +08:00
79e8489942 feat: support siliconflow (#5129) 2024-06-13 12:59:41 +08:00
d6fa130cb5 remove dalle3 seed (#5136)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-06-13 08:05:55 +08:00
0c92f81efc chore: sync pyproject.toml from requirements.txt (#5130) 2024-06-13 00:06:05 +08:00
11fd4a5dcc Fix: fix load_yaml logging, Avoid setting the log level to warning (#5019)
Co-authored-by: huangyusong <huangyusong@yingzi.com>
2024-06-12 19:27:01 +08:00
b399e8a359 fixed a typo and grammar error in sampled app (#5061) 2024-06-12 18:02:22 +08:00
e04fc9b304 fix: select field not work when it is not required (#5101) 2024-06-12 17:46:53 +08:00
ea69dc2a7e feat: support hunyuan llm models (#5013)
Co-authored-by: takatost <takatost@users.noreply.github.com>
Co-authored-by: Bowen Liang <bowenliang@apache.org>
2024-06-12 17:24:23 +08:00
ecc7f130b4 fix(typo): misspelling (#5094) 2024-06-12 17:01:21 +08:00
95443bd551 chore: workflow syncing modal (#5108) 2024-06-12 16:35:19 +08:00
0ce97e6315 feat: support doubao llm function calling (#5100) 2024-06-12 15:43:50 +08:00
25b0a97851 build: use Poetry as default build system for dependency installation in CI jobs (#5088) 2024-06-12 14:43:03 +08:00
28997772a5 fix: remote_url doesn't work for gemini (#5090) 2024-06-12 13:14:53 +08:00
b7c72f7a97 dalle3 add style consistency parameter (#5067)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-06-12 12:59:03 +08:00
9f7b38c068 fix: #4970 (#5093) 2024-06-12 11:29:38 +08:00
3b36ba797f feat: add duckduckgo img search, translate, ai chat (#5074) 2024-06-12 10:04:10 +08:00
4d2e6c3391 fix: Google HL Parameter for SearchApi (#5071) 2024-06-12 08:30:01 +08:00
3520d35f38 fix: autoHeightTextarea dimensions in Firefox (#4891) 2024-06-12 08:24:58 +08:00
5f104bab57 Fix: infinite loading not work when message is too short (#5075) 2024-06-12 08:23:39 +08:00
2050a8b8f0 feat: add glm4 new models and zhipu embedding-2 (#5089) 2024-06-12 08:22:17 +08:00
e3544c6ef7 fix: dependency package versions are not synchronized to requirements.txt (#5084) 2024-06-11 22:21:18 +08:00
472b976946 Arabic README.md (#5078) 2024-06-11 18:43:54 +08:00
f62f71a81a build: initial support for poetry build tool (#4513)
Co-authored-by: Bowen Liang <bowenliang@apache.org>
2024-06-11 13:11:28 +08:00
f426e1b3bd 🔧 Fix(docker/volumes/ssrf_proxy/squid.conf): The squid process on ssrf_proxy docker service crashes at startup (#5050) 2024-06-11 12:32:05 +08:00
5f870ac950 chore: update maas model provider description (#5056) 2024-06-11 11:22:22 +08:00
415816cf35 feat: add dataset delete endpoint (#5048) 2024-06-11 11:21:38 +08:00
9103112555 fix: wrong link to web app repo in chatflow mode (#5062) 2024-06-11 11:20:52 +08:00
5986841e27 fix: issue where an error occurs when invoking TTS without selecting a voice (#5046) 2024-06-09 20:28:24 +08:00
2573b138bf fix: update presence_penalty configuration for wenxin AI ernie-4.0-8k and ernie-3.5-8k models (#5039) 2024-06-09 14:44:11 +08:00
308ce66af5 🔧 fix docker-compose ssrf_proxy service WARNING: You should probably remove '::/0' from the ACL named 'all' (#5005) 2024-06-09 14:39:52 +08:00
bdad993901 improve: generalize vector factory classes and vector type (#5033) 2024-06-08 22:29:24 +08:00
3b62ab564a feat: feature modal style (#5032) 2024-06-08 07:32:34 +08:00
d319d9fc5e fix(style): some style issues (#5029) 2024-06-07 20:59:39 +08:00
ea5c8a72e2 Fix language setting not success (#5023) 2024-06-07 20:02:08 +08:00
3b60c28b3a deal the external image when extract docx image (#5024) 2024-06-07 20:00:39 +08:00
ea0219a5d5 Fix: z-index in header (#5017) 2024-06-07 16:01:33 +08:00
481e7bc6b9 Fix/azure blob new version (#5004) 2024-06-06 23:36:13 +08:00
1ccba85c91 fix: modal z-index and cleanup (#4978) 2024-06-06 22:28:13 +08:00
2539e56514 fix: some base models cannot be selected in Azure OpenAI Service setting page (#4985) 2024-06-06 22:27:57 +08:00
3929d289e0 feat: set default memory messages limit to infinite (#5002) 2024-06-06 17:39:44 +08:00
52585aea74 fix: typo in sd3 (#5000) 2024-06-06 17:08:49 +08:00
73dee84cab fix: add handling for non-string type in variable template parser (#4996) 2024-06-06 16:38:13 +08:00
efecdccf35 feat: support login by given mail (#4991) 2024-06-06 15:01:58 +08:00
da5f2e168a fix: llm selector position is incorrect in not workflow app (#4982) 2024-06-06 10:47:36 +08:00
Joe
5cdb95be1f fix: gemini timeout error (#4955) 2024-06-06 10:19:03 +08:00
7fa735a43b chore: rename vdb tests for PGVector and PGvectoRS (#4973) 2024-06-06 07:22:49 +08:00
3579fd1b09 feat: add create tenant command (#4974) 2024-06-06 00:42:00 +08:00
237b8fe3d9 add meta.doc_id index for tidb (#4963) 2024-06-05 20:45:43 +08:00
02e4de5166 fix some tidb bugs (#4960) 2024-06-05 19:14:18 +08:00
64c8093c1e Typo in Knowledge settings (#4958) 2024-06-05 18:31:24 +08:00
0797f9bc05 feat: support tidb vector (#4588) 2024-06-05 18:19:53 +08:00
602c4e51ec fix: duckduckgo search does not work (#4949)
Co-authored-by: Jyong <76649700+johnjyong@users.noreply.github.com>
2024-06-05 17:33:58 +08:00
YC
9f8ca75a81 fixing a bug of handling header row when parsing xls file, and tune xls/xlsx parsing result to be more structured (#3600) 2024-06-05 15:28:43 +08:00
80a87f36ea fix: missing iterator in task pipeline (#4948) 2024-06-05 15:10:20 +08:00
63addc9258 fix: missing dataset patch parameters in settings modal (#4901) 2024-06-05 14:21:59 +08:00
f32b440c4a chore: fix indention violations by applying E111 to E117 ruff rules (#4925) 2024-06-05 14:05:15 +08:00
6b6afb7708 fix: import error in web/app/components/header/account-setting/model-provider-page/declarations.ts (#4944) 2024-06-05 14:01:12 +08:00
a4041cb40b fix: end node limit in next step (#4945) 2024-06-05 14:00:47 +08:00
7749b71fff Optimize knowledge retrieval performance by batching dataset quries. (#4917) 2024-06-05 13:30:32 +08:00
3006124e6d feat: pricing page add llm load balancing info (#4942) 2024-06-05 11:31:44 +08:00
3d276f4a7f change "Import from text file" to "Import from file" (#4935) 2024-06-05 09:29:29 +08:00
b20d173324 pref: optimize feature model_load_balancing_enabled value fetch speed… (#4933) 2024-06-05 02:06:19 +08:00
f44d1e62d2 fix: bedrock get_num_tokens prompt_messages parameter name err (#4932) 2024-06-05 01:53:05 +08:00
21ac2afb3a fix: question classifier instruction npe (#4931) 2024-06-05 01:27:58 +08:00
f7dd327bc2 version to 0.6.10 (#4929) 2024-06-05 01:12:20 +08:00
09298a32e7 fix: vanna CVE-2024-5565 by disable visualize of ask func (#4930) 2024-06-05 00:46:22 +08:00
37f292ea91 feat: model load balancing (#4926) 2024-06-05 00:13:29 +08:00
d1dbbc1e33 feat: backend model load balancing support (#4927) 2024-06-05 00:13:04 +08:00
52ec152dd3 fix: incorrect parameters transforming while validating (#4928) 2024-06-05 00:01:30 +08:00
c7bddb637b support instruction in classifier node (#4913) 2024-06-04 20:07:54 +08:00
4e3b0c5aea support rename document (#4915) 2024-06-04 20:07:40 +08:00
b6631cd878 modify rerank and splitter code directory (#4924) 2024-06-04 20:07:25 +08:00
c212700341 fix: router replace in Explore page (#4918) 2024-06-04 19:41:54 +08:00
e121788ff5 chore: make the error msg more clear when validate app token (#4919)
Co-authored-by: Jyong <76649700+johnjyong@users.noreply.github.com>
2024-06-04 18:04:10 +08:00
96460d5ea3 feat: document support rename in in dataset (#4732) 2024-06-04 15:10:34 +08:00
9cf9720efa Fix/azure blob token expire (#4914) 2024-06-04 14:30:23 +08:00
2d9f55b632 feat: Add Vanna.AI as a builtin tool (#4878)
Co-authored-by: Yeuoly <admin@srmxy.cn>
2024-06-04 14:05:29 +08:00
7133a16511 chore: refactor the serpapi's google search tool (#4834) 2024-06-04 14:05:05 +08:00
a38dfc006e feat: question classify node support use var in instruction (#4710) 2024-06-04 14:01:40 +08:00
86e7c7321f Fixed a bug where any content in the 'fetch' was converted to True (#4400) 2024-06-04 13:27:23 +08:00
58db719a2c dep: bump pandas from 1.x to 2.x (#4820) 2024-06-04 13:24:28 +08:00
9abeb99b32 chore: modify tools/JinaReader label to Jina (#4908) 2024-06-04 13:21:05 +08:00
d828a7fc35 fix azure blob token expire (#4911) 2024-06-04 13:04:56 +08:00
c6f9ea4434 chore: update page.tsx (#4897) 2024-06-04 10:19:49 +08:00
fb6843815c chore: separate style checks into multiple jobs triggering on file changes (#4876) 2024-06-04 03:03:18 +08:00
b97181a793 chore(deps): bump azure-storage-blob from 12.9.0 to 12.13.0 in /api (#4695)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-06-04 02:57:33 +08:00
5d15aca85f chore: remove unused code and class in text splitter (#4864) 2024-06-04 02:54:09 +08:00
b98a1a3303 feat: added Anthropic Claude3 models to Google Cloud Vertex AI (#4870)
Co-authored-by: pwm <pwm@google.com>
2024-06-04 02:52:46 +08:00
696c5308a9 chore: optimize nvidia nim credential schema and info (#4898) 2024-06-04 02:26:26 +08:00
3542d55e67 improve: generalize tool parameter converter (#4786) 2024-06-03 21:26:58 +08:00
3c8a120e51 add-nvidia-mim (#4882) 2024-06-03 21:10:18 +08:00
cd24308f20 chore: add issue link tempate for IDEA (#4866) 2024-06-03 13:39:54 +08:00
69190e088e fix: update npm version to fix Incorrect argument types in createChatMessage (#4865) 2024-06-03 08:22:27 +08:00
d058a234ba Fixed workflow tts feature audition (#4867) 2024-06-03 00:22:14 +08:00
41e536109b fix: Incorrect argument types in createChatMessage (#4861) 2024-06-02 21:03:42 +08:00
f916aa0f92 chore: upgrade sandbox (#4839) 2024-06-02 11:30:14 +08:00
cdbc260571 Bugfix: Vertex AI vision model not support image (#4853) 2024-06-02 11:11:09 +08:00
b234710af9 chore: fix invalid escape sequences by applying W605 rule (#4851) 2024-06-02 10:02:37 +08:00
23498883d4 chore: skip explicit installing jinja2 as testing dependency (#4845) 2024-06-02 09:49:20 +08:00
a47e8d0da2 test: CI test for db migration scripts on changes (#4739) 2024-05-31 16:45:34 +08:00
6dd0e07af8 test: triggering tests on changes and allow cancelling in-progress CI test jobs (#4743) 2024-05-31 16:42:14 +08:00
b1c9671a60 fix: status query not stop when leaving document embedding detail page (#4754) 2024-05-31 16:07:48 +08:00
7aaa1ff270 chore: increase workflow max steps to 500 (#4835) 2024-05-31 15:16:35 +08:00
85698ca4f7 chore: cleanup tools, remove useless code (#4833) 2024-05-31 14:19:59 +08:00
176d91937d fix 'NoneType' and new ContentType supported. (#4818) 2024-05-31 14:19:33 +08:00
e0da0744b5 add: ollama keep alive parameter added. issue #4024 (#4655) 2024-05-31 12:22:02 +08:00
0b4902bdc2 fix: workflow app run (#4831) 2024-05-31 12:15:25 +08:00
e9904e66e6 chore: Enable case-insensitive search for large models (#4817) 2024-05-31 08:55:37 +08:00
3de8e8fd6a Feat/i18n workflow (#4819) 2024-05-30 21:03:32 +08:00
38a470a873 fix: app_count of dataset is error when apps was deleted (#4810) 2024-05-30 19:23:46 +08:00
4308a79e89 fix: revision styles for workflow (#4087) 2024-05-30 19:10:14 +08:00
93d3350c8c update sd-webui api parameters to v1.9.3 (#4798)
Co-authored-by: Your Name <chen@krasus.red>
2024-05-30 19:04:47 +08:00
615c009c42 fix: remove redundant props (#4787) 2024-05-30 18:58:08 +08:00
a325a294bd feat: opportunistic tls flag for smtp (#4794) 2024-05-30 18:56:46 +08:00
4b91383efc feat: workflow variable aggregator support group (#4811)
Co-authored-by: Yeuoly <admin@srmxy.cn>
2024-05-30 18:54:58 +08:00
18ab63bd37 add: i18n: update korean (#4813) 2024-05-30 17:40:35 +08:00
a7fb1ffcd8 feat: show more usage info in billing page (#4808) 2024-05-30 16:15:38 +08:00
11f173693b fix: some filed in model param selector has no left spacing (#4803) 2024-05-30 14:49:41 +08:00
5b2cd8d03a chore: node help link (#4795) 2024-05-30 14:24:53 +08:00
b10e67be3b Add SearchApi tools (#4648) 2024-05-30 11:11:17 +08:00
d41c077fac chore: improve node user experience (#4792) 2024-05-30 10:53:02 +08:00
3175a2c76a fix: in tool and http node of iteration can not show item var correctly (#4791) 2024-05-30 10:40:27 +08:00
3b60b712ec feat: Add logging warning when MAIL_TYPE is not set (#4771) 2024-05-29 18:06:16 +08:00
afed3610fc fix organize agent's history messages without recalculating tokens (#4324)
Co-authored-by: chenyongzhao <chenyz@mama.cn>
2024-05-29 15:25:20 +08:00
74f38eacda feat: support define tags in tool yaml (#4763) 2024-05-29 15:19:14 +08:00
b189faca52 feat: update ernie model (#4756) 2024-05-29 14:57:23 +08:00
d4cd6149ac fix: incorrect workflow max call depth (#4759) 2024-05-29 14:52:28 +08:00
e1cd9aef8f feat: support baichuan3 turbo, baichuan3 turbo 128k, and baichuan4 (#4762) 2024-05-29 14:46:04 +08:00
ba37275503 fix: confusing chart description (#4760) 2024-05-29 14:36:33 +08:00
e01b44af61 style: fix annotation panel display misalignment (#4750) 2024-05-29 14:23:44 +08:00
72a90074bc Add WORKFLOW_CALL_MAX_DEPTH env var. (#4713) 2024-05-29 13:39:11 +08:00
705a6e3a8e Fix/4742 ollama num gpu option not consistent with allowed values (#4751) 2024-05-29 13:33:35 +08:00
f4a240d225 style: the 'all' of add tool panel should contain workflow tools (#4755) 2024-05-29 13:04:23 +08:00
793f0c1dd6 fix: Corrected schema link in model_runtime's README.md (#4757) 2024-05-29 13:03:21 +08:00
008edd0eeb fix: optimize sticky header styles z-index in tools - ProviderList component (#4746) 2024-05-29 08:36:11 +08:00
9e6b6e7b82 fix: workflow run sequence number slow sql (#4737) 2024-05-28 20:41:52 +08:00
164d6e47b9 Show tool i18n name on chat pannel (#4724) 2024-05-28 18:58:02 +08:00
88b4d69278 fix: Correct context size for banchuan2-53b and banchuan2-turbo (#4721) 2024-05-28 16:37:44 +08:00
5bcbcd3c57 fix: retrieval value greater more than 1 caused ui problem (#4718) 2024-05-28 16:01:19 +08:00
1b2d862973 add error msg for hit test (#4704) 2024-05-28 14:54:53 +08:00
e6f6a59f3b style: update VarPanel to use whitespace-pre-wrap for value display (#4684) 2024-05-28 14:54:29 +08:00
e198bc9b9a fix: workflow as tool garbled (#4707) 2024-05-28 14:51:42 +08:00
b7f81f0999 fix: the new node name is generated based on the original node when duplicating (#4675) 2024-05-28 13:50:43 +08:00
eb8dc15ad6 fix: Input fields in the model provider's settings modal do not switch sequence via keyboard navigation (Tab key) (#4662) 2024-05-28 11:34:44 +08:00
2ee3a1b6f3 fix: key-value-table styles (#4678) 2024-05-28 10:57:40 +08:00
0960b17fbc Add workflow translations for ja-JP (#4698)
Co-authored-by: crazywoola <427733928@qq.com>
2024-05-28 10:27:35 +08:00
6534566b7e feat: add América/São Paulo tz (#4701) 2024-05-28 10:12:18 +08:00
e60350d95d version to 0.6.9 (#4692) 2024-05-27 22:48:34 +08:00
f40743183e 🔧 Add env variable for time signature (#4650) 2024-05-27 22:20:49 +08:00
e852a21634 Feat/workflow phase2 (#4687) 2024-05-27 22:01:11 +08:00
45deaee762 feat: workflow new nodes (#4683)
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Patryk Garstecki <patryk20120@yahoo.pl>
Co-authored-by: Sebastian.W <thiner@gmail.com>
Co-authored-by: 呆萌闷油瓶 <253605712@qq.com>
Co-authored-by: takatost <takatost@users.noreply.github.com>
Co-authored-by: rechardwang <wh_goodjob@163.com>
Co-authored-by: Nite Knite <nkCoding@gmail.com>
Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
Co-authored-by: Joshua <138381132+joshua20231026@users.noreply.github.com>
Co-authored-by: Weaxs <459312872@qq.com>
Co-authored-by: Ikko Eltociear Ashimine <eltociear@gmail.com>
Co-authored-by: leejoo0 <81673835+leejoo0@users.noreply.github.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: sino <sino2322@gmail.com>
Co-authored-by: Vikey Chen <vikeytk@gmail.com>
Co-authored-by: wanghl <Wang-HL@users.noreply.github.com>
Co-authored-by: Haolin Wang-汪皓临 <haolin.wang@atlaslovestravel.com>
Co-authored-by: Zixuan Cheng <61724187+Theysua@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Bowen Liang <bowenliang@apache.org>
Co-authored-by: Bowen Liang <liangbowen@gf.com.cn>
Co-authored-by: fanghongtai <42790567+fanghongtai@users.noreply.github.com>
Co-authored-by: wxfanghongtai <wxfanghongtai@gf.com.cn>
Co-authored-by: Matri <qjp@bithuman.io>
Co-authored-by: Benjamin <benjaminx@gmail.com>
2024-05-27 21:57:08 +08:00
444fdb79dc fix typo: stopParameerRule -> stopParameterRule (#4681) 2024-05-27 20:42:07 +08:00
140dd873f1 fix: show exception message when sandbox execution fails (#4663) 2024-05-27 18:06:15 +08:00
27dae156db fix: colon in file mistral.mistral-small-2402-v1:0 (#4673) 2024-05-27 13:15:20 +08:00
2deb23e00e fix: Show rerank in system for localai (#4652) 2024-05-27 12:09:51 +08:00
8152bc6fbf Feat/upgrade check i18n scripts (#4671) 2024-05-27 10:36:34 +08:00
af026c5953 fix: node.js sdk if request is a get data must not exist (#4618) 2024-05-27 08:48:07 +08:00
11275cbaaf fix: z-index (#4065) 2024-05-27 08:47:27 +08:00
fe9bf5fc4a [seanguo] add support of amazon titan v2 and modify the price of amazon titan v1 (#4643)
Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
2024-05-26 23:30:22 +08:00
cd4924d472 Fix tts audition (#4656)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-05-26 21:00:36 +08:00
f56b984d97 Fix Unnecessary Newline Characters in Extracted Tool Response Text (#4646)
Co-authored-by: kronus <kronus@istarshine.com>
2024-05-25 15:24:59 +08:00
f804adbff3 feat: Support for Vertex AI - load Default Application Configuration (#4641)
Co-authored-by: miendinh <miendinh@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-05-25 13:40:25 +08:00
109aabc6f2 fix: incorrect handling when http header value contain multiple colons. (#4574) 2024-05-25 13:20:06 +08:00
ad620f02c7 fix: the date is incorrect if the db field is timestamp and the TZ is not the UTC (#4624)
Co-authored-by: liuzhenghua-jk <liuzhenghua-jk@360shuke.com>
2024-05-25 13:11:18 +08:00
5bd432a85f Fix tts audition (#4637)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-05-25 12:55:04 +08:00
026175c8f7 feat: update notion extractor (#3898)
Co-authored-by: duyalei <>
2024-05-24 20:30:48 +08:00
f156014daa update lite8k/speed8k/128k max_token to newest (#4636)
Co-authored-by: Your Name <chen@krasus.red>
2024-05-24 19:33:42 +08:00
2cd55e456a fix: WORKFLOW_MAX_EXECUTION_STEPS spell error in config.py (#4642) 2024-05-24 16:32:26 +08:00
8c2ca60c8b feat: Add WORKFLOW_MAX_EXECUTION_TIME env var (#4632) 2024-05-24 15:27:12 +08:00
3fda2245a4 improve: extract method for safe loading yaml file and avoid using PyYaml's FullLoader (#4031) 2024-05-24 12:08:12 +08:00
296887754f Support for Vertex AI (#4586) 2024-05-24 12:01:40 +08:00
9ae72cdcf4 feat: Add Gemini Flash (#4616) 2024-05-24 11:43:06 +08:00
16fec084f5 Fix/4630 bug api suggested (#4633) 2024-05-24 11:11:12 +08:00
10c61da686 feat: add confirm ui (#4625) 2024-05-23 20:15:51 +08:00
24624491cd add qdrant metadata.doc_id index when create qdrant collection (#4570) 2024-05-23 18:11:01 +08:00
233c4150d1 support images and tables extract from docx (#4619) 2024-05-23 18:05:23 +08:00
5893ebec55 fix: code node garbled in Javascript (#4615) 2024-05-23 17:18:57 +08:00
11642192d1 chore: add https://api.openai.com placeholder in OpenAI api base (#4604) 2024-05-23 12:56:05 +08:00
e57bdd4e58 chore:update gpt-3.5-turbo and gpt-4-turbo parameter for azure (#4596) 2024-05-23 11:51:38 +08:00
461488e9bf Add Azure OpenAI API version for GPT4o support (#4569)
Co-authored-by: wwwc <wwwc@outlook.com>
2024-05-22 17:43:16 +08:00
2988b67c24 fix: hide automatic button on automatic result page (#4494) 2024-05-22 16:44:20 +08:00
4f62541bfb chore: remove model provider free token link (#4579) 2024-05-22 16:42:49 +08:00
24576a39e5 fix: some google search result raise exception (#4567) 2024-05-22 14:28:52 +08:00
3ab19be9ea Fix bedrock claude wrong pricing (#4572)
Co-authored-by: Justin Wu <justin.wu@ringcentral.com>
2024-05-22 14:28:28 +08:00
5b009a5afb chore(api): Use channel from UI as API query parameter (#4562) 2024-05-22 14:28:03 +08:00
3efb5fe7e2 Refactor part of the ProviderManager code to improve readability (#4524) 2024-05-22 11:18:03 +08:00
ee53f98d8c Hide the copy button when there is no content to copy (#4546) 2024-05-22 11:15:13 +08:00
d5a33a0323 feat:add gpt-4o for azure (#4568) 2024-05-22 11:02:43 +08:00
f25927855e add qdrant metadata.doc_id index (#4559) 2024-05-22 01:42:08 +08:00
c873035084 oauth2 supports. (#4551) 2024-05-21 17:52:41 +08:00
f32bba6531 Update requirements.txt with latest OSS package with AuthV4 support (#4425) 2024-05-20 17:26:07 +08:00
40bc936739 chore: update yfinance dependency to version 0.2.40 (#4517) 2024-05-20 16:47:05 +08:00
6b5685ef0c feat: Jina Search & Jina Reader CSS selectors (#4523) 2024-05-20 16:40:46 +08:00
e8e213ad1e chore: apply and fix flake8-bugbear lint rules (#4496) 2024-05-20 16:34:13 +08:00
5f4df34829 improve: generalize transformations and scripts of runner and preloads into TemplateTransformer (#4487) 2024-05-20 15:56:26 +08:00
860 changed files with 41335 additions and 7264 deletions

View File

@ -4,6 +4,13 @@ on:
pull_request:
branches:
- main
paths:
- api/**
- docker/**
concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
@ -51,7 +58,7 @@ jobs:
- name: Run Workflow
run: dev/pytest/pytest_workflow.sh
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
@ -60,6 +67,7 @@ jobs:
docker/docker-compose.milvus.yaml
docker/docker-compose.pgvecto-rs.yaml
docker/docker-compose.pgvector.yaml
docker/docker-compose.chroma.yaml
services: |
weaviate
qdrant
@ -68,6 +76,82 @@ jobs:
milvus-standalone
pgvecto-rs
pgvector
chroma
- name: Test Vector Stores
run: dev/pytest/pytest_vdb.sh
test-in-poetry:
name: API Tests
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- "3.10"
- "3.11"
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'poetry'
cache-dependency-path: |
api/pyproject.toml
api/poetry.lock
- name: Poetry check
run: poetry check -C api
- name: Install dependencies
run: poetry install -C api --with dev
- name: Run Unit tests
run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh
- name: Run ModelRuntime
run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
- name: Run Tool
run: poetry run -C api bash dev/pytest/pytest_tools.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
sandbox
ssrf_proxy
- name: Run Workflow
run: poetry run -C api bash dev/pytest/pytest_workflow.sh
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
docker/docker-compose.qdrant.yaml
docker/docker-compose.milvus.yaml
docker/docker-compose.pgvecto-rs.yaml
docker/docker-compose.pgvector.yaml
docker/docker-compose.chroma.yaml
services: |
weaviate
qdrant
etcd
minio
milvus-standalone
pgvecto-rs
pgvector
chroma
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh

57
.github/workflows/db-migration-test.yml vendored Normal file
View File

@ -0,0 +1,57 @@
name: DB Migration Test
on:
pull_request:
branches:
- main
paths:
- api/migrations/**
concurrency:
group: db-migration-test-${{ github.ref }}
cancel-in-progress: true
jobs:
db-migration-test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- "3.10"
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'poetry'
cache-dependency-path: |
api/pyproject.toml
api/poetry.lock
- name: Install dependencies
run: poetry install -C api
- name: Set up Middleware
uses: hoverkraft-tech/compose-action@v2.0.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db
- name: Prepare configs
run: |
cd api
cp .env.example .env
- name: Run DB Migration
run: |
cd api
poetry run python -m flask db upgrade

View File

@ -6,7 +6,7 @@ on:
- main
concurrency:
group: dep-${{ github.head_ref || github.run_id }}
group: style-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
@ -18,54 +18,92 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v44
with:
files: api/**
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Set up Python
uses: actions/setup-python@v5
if: steps.changed-files.outputs.any_changed == 'true'
with:
python-version: '3.10'
- name: Python dependencies
run: pip install ruff dotenv-linter
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry install -C api --only lint
- name: Ruff check
run: ruff check ./api
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api ruff check --preview ./api
- name: Dotenv check
run: dotenv-linter ./api/.env.example ./web/.env.example
if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
test:
name: ESLint and SuperLinter
web-style:
name: Web Style
runs-on: ubuntu-latest
needs: python-style
defaults:
run:
working-directory: ./web
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v44
with:
fetch-depth: 0
files: web/**
- name: Setup NodeJS
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 20
cache: yarn
cache-dependency-path: ./web/package.json
- name: Web dependencies
run: |
cd ./web
yarn install --frozen-lockfile
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn install --frozen-lockfile
- name: Web style check
run: |
cd ./web
yarn run lint
if: steps.changed-files.outputs.any_changed == 'true'
run: yarn run lint
superlinter:
name: SuperLinter
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v44
with:
files: |
**.sh
**.yaml
**.yml
Dockerfile
- name: Super-linter
uses: super-linter/super-linter/slim@v6
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning
DEFAULT_BRANCH: main
@ -76,4 +114,5 @@ jobs:
VALIDATE_BASH_EXEC: true
VALIDATE_GITHUB_ACTIONS: true
VALIDATE_DOCKERFILE_HADOLINT: true
VALIDATE_XML: true
VALIDATE_YAML: true

View File

@ -4,6 +4,13 @@ on:
pull_request:
branches:
- main
paths:
- sdks/**
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
build:
name: unit test for Node.js SDK

4
.gitignore vendored
View File

@ -134,7 +134,8 @@ dmypy.json
web/.vscode/settings.json
# Intellij IDEA Files
.idea/
.idea/*
!.idea/vcs.xml
.ideaDataSources/
api/.env
@ -148,6 +149,7 @@ docker/volumes/qdrant/*
docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*
docker/volumes/chroma/*
sdks/python-client/build
sdks/python-client/dist

16
.idea/vcs.xml generated Normal file
View File

@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="IssueNavigationConfiguration">
<option name="links">
<list>
<IssueNavigationLink>
<option name="issueRegexp" value="#(\d+)" />
<option name="linkRegexp" value="https://github.com/langgenius/dify/issues/$1" />
</IssueNavigationLink>
</list>
</option>
</component>
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

View File

@ -36,6 +36,7 @@
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
</p>

225
README_AR.md Normal file
View File

@ -0,0 +1,225 @@
![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab)
<p align="center">
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
<a href="https://docs.dify.ai/getting-started/install-self-hosted">الاستضافة الذاتية</a> ·
<a href="https://docs.dify.ai">التوثيق</a> ·
<a href="https://cal.com/guchenhe/60-min-meeting">استفسارات الشركات</a>
</p>
<p align="center">
<a href="https://dify.ai" target="_blank">
<img alt="Static Badge" src="https://img.shields.io/badge/Product-F04438"></a>
<a href="https://dify.ai/pricing" target="_blank">
<img alt="Static Badge" src="https://img.shields.io/badge/free-pricing?logo=free&color=%20%23155EEF&label=pricing&labelColor=%20%23528bff"></a>
<a href="https://discord.gg/FngNHpbcY7" target="_blank">
<img src="https://img.shields.io/discord/1082486657678311454?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb"
alt="chat on Discord"></a>
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="follow on Twitter"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
<img alt="Commits last month" src="https://img.shields.io/github/commit-activity/m/langgenius/dify?labelColor=%20%2332b583&color=%20%2312b76a"></a>
<a href="https://github.com/langgenius/dify/" target="_blank">
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
<a href="https://github.com/langgenius/dify/discussions/" target="_blank">
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
</p>
<p align="center">
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
<a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
<a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
</p>
<div style="text-align: right;">
مشروع Dify هو منصة تطوير تطبيقات الذكاء الصناعي مفتوحة المصدر. تجمع واجهته البديهية بين سير العمل الذكي بالذكاء الاصطناعي وخط أنابيب RAG وقدرات الوكيل وإدارة النماذج وميزات الملاحظة وأكثر من ذلك، مما يتيح لك الانتقال بسرعة من المرحلة التجريبية إلى الإنتاج. إليك قائمة بالميزات الأساسية:
</br> </br>
**1. سير العمل**: قم ببناء واختبار سير عمل الذكاء الاصطناعي القوي على قماش بصري، مستفيدًا من جميع الميزات التالية وأكثر.
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
**2. الدعم الشامل للنماذج**: تكامل سلس مع مئات من LLMs الخاصة / مفتوحة المصدر من عشرات من موفري التحليل والحلول المستضافة ذاتيًا، مما يغطي GPT و Mistral و Llama3 وأي نماذج متوافقة مع واجهة OpenAI API. يمكن العثور على قائمة كاملة بمزودي النموذج المدعومين [هنا](https://docs.dify.ai/getting-started/readme/model-providers).
![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3)
**3. بيئة التطوير للأوامر**: واجهة بيئة التطوير المبتكرة لصياغة الأمر ومقارنة أداء النموذج، وإضافة ميزات إضافية مثل تحويل النص إلى كلام إلى تطبيق قائم على الدردشة.
**4. خط أنابيب RAG**: قدرات RAG الواسعة التي تغطي كل شيء من استيعاب الوثائق إلى الاسترجاع، مع الدعم الفوري لاستخراج النص من ملفات PDF و PPT وتنسيقات الوثائق الشائعة الأخرى.
**5. قدرات الوكيل**: يمكنك تعريف الوكلاء بناءً على أمر وظيفة LLM أو ReAct، وإضافة أدوات مدمجة أو مخصصة للوكيل. توفر Dify أكثر من 50 أداة مدمجة لوكلاء الذكاء الاصطناعي، مثل البحث في Google و DELL·E وStable Diffusion و WolframAlpha.
**6. الـ LLMOps**: راقب وتحلل سجلات التطبيق والأداء على مر الزمن. يمكنك تحسين الأوامر والبيانات والنماذج باستمرار استنادًا إلى البيانات الإنتاجية والتعليقات.
**7.الواجهة الخلفية (Backend) كخدمة**: تأتي جميع عروض Dify مع APIs مطابقة، حتى يمكنك دمج Dify بسهولة في منطق أعمالك الخاص.
## مقارنة الميزات
<table style="width: 100%;">
<tr>
<th align="center">الميزة</th>
<th align="center">Dify.AI</th>
<th align="center">LangChain</th>
<th align="center">Flowise</th>
<th align="center">OpenAI Assistants API</th>
</tr>
<tr>
<td align="center">نهج البرمجة</td>
<td align="center">موجّه لـ تطبيق + واجهة برمجة تطبيق (API)</td>
<td align="center">برمجة Python</td>
<td align="center">موجه لتطبيق</td>
<td align="center">واجهة برمجة تطبيق (API)</td>
</tr>
<tr>
<td align="center">LLMs المدعومة</td>
<td align="center">تنوع غني</td>
<td align="center">تنوع غني</td>
<td align="center">تنوع غني</td>
<td align="center">فقط OpenAI</td>
</tr>
<tr>
<td align="center">محرك RAG</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">الوكيل</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">سير العمل</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">الملاحظة</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">ميزات الشركات (SSO / مراقبة الوصول)</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">نشر محلي</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
</table>
## استخدام Dify
- **سحابة </br>**
نحن نستضيف [خدمة Dify Cloud](https://dify.ai) لأي شخص لتجربتها بدون أي إعدادات. توفر كل قدرات النسخة التي تمت استضافتها ذاتيًا، وتتضمن 200 أمر GPT-4 مجانًا في خطة الصندوق الرملي.
- **استضافة ذاتية لنسخة المجتمع Dify</br>**
ابدأ سريعًا في تشغيل Dify في بيئتك باستخدام [دليل البدء السريع](#البدء السريع).
استخدم [توثيقنا](https://docs.dify.ai) للمزيد من المراجع والتعليمات الأعمق.
- **مشروع Dify للشركات / المؤسسات</br>**
نحن نوفر ميزات إضافية مركزة على الشركات. [جدول اجتماع معنا](https://cal.com/guchenhe/30min) أو [أرسل لنا بريدًا إلكترونيًا](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) لمناقشة احتياجات الشركات. </br>
> بالنسبة للشركات الناشئة والشركات الصغيرة التي تستخدم خدمات AWS، تحقق من [Dify Premium على AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) ونشرها في شبكتك الخاصة على AWS VPC بنقرة واحدة. إنها عرض AMI بأسعار معقولة مع خيار إنشاء تطبيقات بشعار وعلامة تجارية مخصصة.
## البقاء قدمًا
قم بإضافة نجمة إلى Dify على GitHub وتلق تنبيهًا فوريًا بالإصدارات الجديدة.
![نجمنا](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4)
## البداية السريعة
> قبل تثبيت Dify، تأكد من أن جهازك يلبي الحد الأدنى من متطلبات النظام التالية:
>
>- معالج >= 2 نواة
>- ذاكرة وصول عشوائي (RAM) >= 4 جيجابايت
</br>
أسهل طريقة لبدء تشغيل خادم Dify هي تشغيل ملف [docker-compose.yml](docker/docker-compose.yaml) الخاص بنا. قبل تشغيل أمر التثبيت، تأكد من تثبيت [Docker](https://docs.docker.com/get-docker/) و [Docker Compose](https://docs.docker.com/compose/install/) على جهازك:
```bash
cd docker
docker compose up -d
```
بعد التشغيل، يمكنك الوصول إلى لوحة تحكم Dify في متصفحك على [http://localhost/install](http://localhost/install) وبدء عملية التهيئة.
> إذا كنت ترغب في المساهمة في Dify أو القيام بتطوير إضافي، فانظر إلى [دليلنا للنشر من الشفرة (code) المصدرية](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
## الخطوات التالية
إذا كنت بحاجة إلى تخصيص التكوين، يرجى الرجوع إلى التعليقات في ملف [docker-compose.yml](docker/docker-compose.yaml) لدينا وتعيين التكوينات البيئية يدويًا. بعد إجراء التغييرات، يرجى تشغيل `docker-compose up -d` مرة أخرى. يمكنك رؤية قائمة كاملة بالمتغيرات البيئية [هنا](https://docs.dify.ai/getting-started/install-self-hosted/environments).
إذا كنت ترغب في تكوين إعداد متوفر بشكل عالي، فهناك [رسوم بيانية Helm](https://helm.sh/) المساهمة من المجتمع تسمح بنشر Dify على Kubernetes.
- [رسم بياني Helm من قبل @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [رسم بياني Helm من قبل @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
## المساهمة
لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا.
في الوقت نفسه، يرجى النظر في دعم Dify عن طريق مشاركته على وسائل التواصل الاجتماعي وفي الفعاليات والمؤتمرات.
> نحن نبحث عن مساهمين لمساعدة في ترجمة Dify إلى لغات أخرى غير اللغة الصينية المندرين أو الإنجليزية. إذا كنت مهتمًا بالمساعدة، يرجى الاطلاع على [README للترجمة](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) لمزيد من المعلومات، واترك لنا تعليقًا في قناة `global-users` على [خادم المجتمع على Discord](https://discord.gg/8Tpq4AcN9c).
**المساهمون**
<a href="https://github.com/langgenius/dify/graphs/contributors">
<img src="https://contrib.rocks/image?repo=langgenius/dify" />
</a>
## المجتمع والاتصال
* [مناقشة Github](https://github.com/langgenius/dify/discussions). الأفضل لـ: مشاركة التعليقات وطرح الأسئلة.
* [المشكلات على GitHub](https://github.com/langgenius/dify/issues). الأفضل لـ: الأخطاء التي تواجهها في استخدام Dify.AI، واقتراحات الميزات. انظر [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
* [البريد الإلكتروني](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). الأفضل لـ: الأسئلة التي تتعلق باستخدام Dify.AI.
* [Discord](https://discord.gg/FngNHpbcY7). الأفضل لـ: مشاركة تطبيقاتك والترفيه مع المجتمع.
* [تويتر](https://twitter.com/dify_ai). الأفضل لـ: مشاركة تطبيقاتك والترفيه مع المجتمع.
أو، قم بجدولة اجتماع مباشرة مع أحد أعضاء الفريق:
<table>
<tr>
<th>نقطة الاتصال</th>
<th>الغرض</th>
</tr>
<tr>
<td><a href='https://cal.com/guchenhe/15min' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/9ebcd111-1205-4d71-83d5-948d70b809f5' alt='Git-Hub-README-Button-3x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
<td>استفسارات الأعمال واقتراحات حول المنتج</td>
</tr>
<tr>
<td><a href='https://cal.com/pinkbanana' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/d1edd00a-d7e4-4513-be6c-e57038e143fd' alt='Git-Hub-README-Button-2x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
<td>المساهمات والمشكلات وطلبات الميزات</td>
</tr>
</table>
## تاريخ النجمة
[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
## الكشف عن الأمان
لحماية خصوصيتك، يرجى تجنب نشر مشكلات الأمان على GitHub. بدلاً من ذلك، أرسل أسئلتك إلى security@dify.ai وسنقدم لك إجابة أكثر تفصيلاً.
## الرخصة
هذا المستودع متاح تحت [رخصة البرنامج الحر Dify](LICENSE)، والتي تعتبر بشكل أساسي Apache 2.0 مع بعض القيود الإضافية.

View File

@ -17,6 +17,9 @@ APP_WEB_URL=http://127.0.0.1:3000
# Files URL
FILES_URL=http://127.0.0.1:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
@ -109,6 +112,21 @@ PGVECTOR_USER=postgres
PGVECTOR_PASSWORD=postgres
PGVECTOR_DATABASE=postgres
# Tidb Vector configuration
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
TIDB_VECTOR_PORT=4000
TIDB_VECTOR_USER=xxx.root
TIDB_VECTOR_PASSWORD=xxxxxx
TIDB_VECTOR_DATABASE=dify
# Chroma configuration
CHROMA_HOST=127.0.0.1
CHROMA_PORT=8000
CHROMA_TENANT=default_tenant
CHROMA_DATABASE=default_database
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
CHROMA_AUTH_CREDENTIALS=difyai123456
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
@ -124,10 +142,11 @@ RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
SMTP_SERVER=smtp.gmail.com
SMTP_PORT=587
SMTP_PORT=465
SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=false
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Sentry configuration
SENTRY_DSN=
@ -179,3 +198,8 @@ LOG_FILE=
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
# Workflow runtime configuration
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5

View File

@ -17,7 +17,8 @@
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},
"console": "integratedTerminal"
"console": "integratedTerminal",
"python": "${command:python.interpreterPath}"
},
{
"name": "Python: Flask",
@ -36,7 +37,8 @@
"--debug"
],
"jinja": true,
"justMyCode": true
"justMyCode": true,
"python": "${command:python.interpreterPath}"
}
]
}

View File

@ -17,15 +17,30 @@
```bash
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```
4. If you use Anaconda, create a new environment and activate it
4. Create environment.
- Anaconda
If you use Anaconda, create a new environment and activate it
```bash
conda create --name dify python=3.10
conda activate dify
```
- Poetry
If you use Poetry, you don't need to manually create the environment. You can execute `poetry shell` to activate the environment.
5. Install dependencies
- Anaconda
```bash
pip install -r requirements.txt
```
- Poetry
```bash
poetry install
```
In case of contributors missing to update dependencies for `pyproject.toml`, you can perform the following shell instead.
```base
poetry shell # activate current environment
poetry add $(cat requirements.txt) # install dependencies of production and update pyproject.toml
poetry add $(cat requirements-dev.txt) --group dev # install dependencies of development and update pyproject.toml
```
6. Run migrate
Before the first launch, migrate the database to the latest version.

View File

@ -1,12 +1,15 @@
import base64
import json
import secrets
from typing import Optional
import click
from flask import current_app
from werkzeug.exceptions import NotFound
from constants.languages import languages
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_database import db
from libs.helper import email as email_validate
@ -17,6 +20,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService
@click.command('reset-password', help='Reset the account password.')
@ -57,7 +61,7 @@ def reset_password(email, new_password, password_confirm):
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
click.echo(click.style('Congratulations!, password has been reset.', fg='green'))
click.echo(click.style('Congratulations! Password has been reset.', fg='green'))
@click.command('reset-email', help='Reset the account email.')
@ -263,15 +267,15 @@ def migrate_knowledge_vector_database():
skipped_count = skipped_count + 1
continue
collection_name = ''
if vector_type == "weaviate":
if vector_type == VectorType.WEAVIATE:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'weaviate',
"type": VectorType.WEAVIATE,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "qdrant":
elif vector_type == VectorType.QDRANT:
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
@ -284,20 +288,20 @@ def migrate_knowledge_vector_database():
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'qdrant',
"type": VectorType.QDRANT,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "milvus":
elif vector_type == VectorType.MILVUS:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'milvus',
"type": VectorType.MILVUS,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "relyt":
elif vector_type == VectorType.RELYT:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
@ -305,16 +309,16 @@ def migrate_knowledge_vector_database():
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
elif vector_type == "pgvector":
elif vector_type == VectorType.PGVECTOR:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'pgvector',
"type": VectorType.PGVECTOR,
"vector_store": {"class_prefix": collection_name}
}
dataset.index_struct = json.dumps(index_struct_dict)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
raise ValueError(f"Vector store {vector_type} is not supported.")
vector = Vector(dataset)
click.echo(f"Start to migrate dataset {dataset.id}.")
@ -448,9 +452,105 @@ def convert_to_agent_apps():
click.echo(click.style('Congratulations! Converted {} agent apps.'.format(len(proceeded_app_ids)), fg='green'))
@click.command('add-qdrant-doc-id-index', help='add qdrant doc_id index.')
@click.option('--field', default='metadata.doc_id', prompt=False, help='index field , default is metadata.doc_id.')
def add_qdrant_doc_id_index(field: str):
click.echo(click.style('Start add qdrant doc_id index.', fg='green'))
config = current_app.config
vector_type = config.get('VECTOR_STORE')
if vector_type != "qdrant":
click.echo(click.style('Sorry, only support qdrant vector store.', fg='red'))
return
create_count = 0
try:
bindings = db.session.query(DatasetCollectionBinding).all()
if not bindings:
click.echo(click.style('Sorry, no dataset collection bindings found.', fg='red'))
return
import qdrant_client
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig
for binding in bindings:
qdrant_config = QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
grpc_port=config.get('QDRANT_GRPC_PORT'),
prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
)
try:
client = qdrant_client.QdrantClient(**qdrant_config.to_qdrant_params())
# create payload index
client.create_payload_index(binding.collection_name, field,
field_schema=PayloadSchemaType.KEYWORD)
create_count += 1
except UnexpectedResponse as e:
# Collection does not exist, so return
if e.status_code == 404:
click.echo(click.style(f'Collection not found, collection_name:{binding.collection_name}.', fg='red'))
continue
# Some other error occurred, so re-raise the exception
else:
click.echo(click.style(f'Failed to create qdrant index, collection_name:{binding.collection_name}.', fg='red'))
except Exception as e:
click.echo(click.style('Failed to create qdrant client.', fg='red'))
click.echo(
click.style(f'Congratulations! Create {create_count} collection indexes.',
fg='green'))
@click.command('create-tenant', help='Create account and tenant.')
@click.option('--email', prompt=True, help='The email address of the tenant account.')
@click.option('--language', prompt=True, help='Account language, default: en-US.')
def create_tenant(email: str, language: Optional[str] = None):
"""
Create tenant account
"""
if not email:
click.echo(click.style('Sorry, email is required.', fg='red'))
return
# Create account
email = email.strip()
if '@' not in email:
click.echo(click.style('Sorry, invalid email address.', fg='red'))
return
account_name = email.split('@')[0]
if language not in languages:
language = 'en-US'
# generate random password
new_password = secrets.token_urlsafe(16)
# register account
account = RegisterService.register(
email=email,
name=account_name,
password=new_password,
language=language
)
TenantService.create_owner_tenant_if_not_exist(account)
click.echo(click.style('Congratulations! Account and tenant created.\n'
'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(vdb_migrate)
app.cli.add_command(convert_to_agent_apps)
app.cli.add_command(add_qdrant_doc_id_index)
app.cli.add_command(create_tenant)

View File

@ -23,6 +23,7 @@ DEFAULTS = {
'SERVICE_API_URL': 'https://api.dify.ai',
'APP_WEB_URL': 'https://udify.app',
'FILES_URL': '',
'FILES_ACCESS_TIMEOUT': 300,
'S3_ADDRESS_STYLE': 'auto',
'STORAGE_TYPE': 'local',
'STORAGE_LOCAL_PATH': 'storage',
@ -69,6 +70,7 @@ DEFAULTS = {
'INVITE_EXPIRY_HOURS': 72,
'BILLING_ENABLED': 'False',
'CAN_REPLACE_LOGO': 'False',
'MODEL_LB_ENABLED': 'False',
'ETL_TYPE': 'dify',
'KEYWORD_STORE': 'jieba',
'BATCH_UPLOAD_LIMIT': 20,
@ -80,6 +82,9 @@ DEFAULTS = {
'INNER_API': 'False',
'ENTERPRISE_ENABLED': 'False',
'INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH': 1000,
'WORKFLOW_MAX_EXECUTION_STEPS': 500,
'WORKFLOW_MAX_EXECUTION_TIME': 1200,
'WORKFLOW_CALL_MAX_DEPTH': 5,
}
@ -110,7 +115,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.6.8"
self.CURRENT_VERSION = "0.6.10"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = get_env('EDITION')
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@ -119,6 +124,7 @@ class Config:
self.LOG_FILE = get_env('LOG_FILE')
self.LOG_FORMAT = get_env('LOG_FORMAT')
self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT')
self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
# The backend URL prefix of the console API.
# used to concatenate the login authorization callback or notion integration callback.
@ -141,6 +147,10 @@ class Config:
# Url is signed and has expiration time.
self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL
# File Access Time specifies a time interval in seconds for the file to be accessed.
# The default value is 300 seconds.
self.FILES_ACCESS_TIMEOUT = int(get_env('FILES_ACCESS_TIMEOUT'))
# Your App secret key will be used for securely signing the session cookie
# Make sure you are changing this key for your deployment with a strong key.
# You can generate a strong key using `openssl rand -base64 42`.
@ -177,7 +187,8 @@ class Config:
'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
'max_overflow': int(get_env('SQLALCHEMY_MAX_OVERFLOW')),
'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')),
'pool_pre_ping': get_bool_env('SQLALCHEMY_POOL_PRE_PING')
'pool_pre_ping': get_bool_env('SQLALCHEMY_POOL_PRE_PING'),
'connect_args': {'options': '-c timezone=UTC'},
}
self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
@ -201,27 +212,41 @@ class Config:
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
# ------------------------
# Code Execution Sandbox Configurations.
# ------------------------
self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
# ------------------------
# File Storage Configurations.
# ------------------------
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
# S3 Storage settings
self.S3_ENDPOINT = get_env('S3_ENDPOINT')
self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
self.S3_REGION = get_env('S3_REGION')
self.S3_ADDRESS_STYLE = get_env('S3_ADDRESS_STYLE')
# Azure Blob Storage settings
self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME')
self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')
self.ALIYUN_OSS_BUCKET_NAME=get_env('ALIYUN_OSS_BUCKET_NAME')
self.ALIYUN_OSS_ACCESS_KEY=get_env('ALIYUN_OSS_ACCESS_KEY')
self.ALIYUN_OSS_SECRET_KEY=get_env('ALIYUN_OSS_SECRET_KEY')
self.ALIYUN_OSS_ENDPOINT=get_env('ALIYUN_OSS_ENDPOINT')
self.ALIYUN_OSS_REGION=get_env('ALIYUN_OSS_REGION')
self.ALIYUN_OSS_AUTH_VERSION=get_env('ALIYUN_OSS_AUTH_VERSION')
# Aliyun Storage settings
self.ALIYUN_OSS_BUCKET_NAME = get_env('ALIYUN_OSS_BUCKET_NAME')
self.ALIYUN_OSS_ACCESS_KEY = get_env('ALIYUN_OSS_ACCESS_KEY')
self.ALIYUN_OSS_SECRET_KEY = get_env('ALIYUN_OSS_SECRET_KEY')
self.ALIYUN_OSS_ENDPOINT = get_env('ALIYUN_OSS_ENDPOINT')
self.ALIYUN_OSS_REGION = get_env('ALIYUN_OSS_REGION')
self.ALIYUN_OSS_AUTH_VERSION = get_env('ALIYUN_OSS_AUTH_VERSION')
# Google Cloud Storage settings
self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME')
self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')
@ -231,6 +256,7 @@ class Config:
# ------------------------
self.VECTOR_STORE = get_env('VECTOR_STORE')
self.KEYWORD_STORE = get_env('KEYWORD_STORE')
# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
@ -273,6 +299,21 @@ class Config:
self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')
# tidb-vector settings
self.TIDB_VECTOR_HOST = get_env('TIDB_VECTOR_HOST')
self.TIDB_VECTOR_PORT = get_env('TIDB_VECTOR_PORT')
self.TIDB_VECTOR_USER = get_env('TIDB_VECTOR_USER')
self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD')
self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE')
# chroma settings
self.CHROMA_HOST = get_env('CHROMA_HOST')
self.CHROMA_PORT = get_env('CHROMA_PORT')
self.CHROMA_TENANT = get_env('CHROMA_TENANT')
self.CHROMA_DATABASE = get_env('CHROMA_DATABASE')
self.CHROMA_AUTH_PROVIDER = get_env('CHROMA_AUTH_PROVIDER')
self.CHROMA_AUTH_CREDENTIALS = get_env('CHROMA_AUTH_CREDENTIALS')
# ------------------------
# Mail Configurations.
# ------------------------
@ -286,7 +327,8 @@ class Config:
self.SMTP_USERNAME = get_env('SMTP_USERNAME')
self.SMTP_PASSWORD = get_env('SMTP_PASSWORD')
self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
self.SMTP_OPPORTUNISTIC_TLS = get_bool_env('SMTP_OPPORTUNISTIC_TLS')
# ------------------------
# Workspace Configurations.
# ------------------------
@ -313,6 +355,23 @@ class Config:
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
# RAG ETL Configurations.
self.ETL_TYPE = get_env('ETL_TYPE')
self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY')
self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
# Indexing Configurations.
self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH')
# Tool Configurations.
self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
self.WORKFLOW_MAX_EXECUTION_STEPS = int(get_env('WORKFLOW_MAX_EXECUTION_STEPS'))
self.WORKFLOW_MAX_EXECUTION_TIME = int(get_env('WORKFLOW_MAX_EXECUTION_TIME'))
self.WORKFLOW_CALL_MAX_DEPTH = int(get_env('WORKFLOW_CALL_MAX_DEPTH'))
# Moderation in app Configurations.
self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
@ -364,24 +423,15 @@ class Config:
self.HOSTED_FETCH_APP_TEMPLATES_MODE = get_env('HOSTED_FETCH_APP_TEMPLATES_MODE')
self.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = get_env('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN')
self.ETL_TYPE = get_env('ETL_TYPE')
self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY')
# Model Load Balancing Configurations.
self.MODEL_LB_ENABLED = get_bool_env('MODEL_LB_ENABLED')
# Platform Billing Configurations.
self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')
self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
# ------------------------
# Enterprise feature Configurations.
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
# ------------------------
self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
# ------------------------
# Indexing Configurations.
# ------------------------
self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH')
self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')

File diff suppressed because one or more lines are too long

View File

@ -54,4 +54,4 @@ from .explore import (
from .tag import tags
# Import workspace controllers
from .workspace import account, members, model_providers, models, tool_providers, workspace
from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace

View File

@ -85,7 +85,7 @@ class ChatMessageTextApi(Resource):
response = AudioService.transcript_tts(
app_model=app_model,
text=request.form['text'],
voice=request.form.get('voice'),
voice=request.form['voice'],
streaming=False
)

View File

@ -137,6 +137,71 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
logging.exception("internal server error.")
raise InternalServerError()
class AdvancedChatDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, location='json')
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model,
user=current_user,
node_id=node_id,
args=args,
streaming=True
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class WorkflowDraftRunIterationNodeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, location='json')
args = parser.parse_args()
try:
response = AppGenerateService.generate_single_iteration(
app_model=app_model,
user=current_user,
node_id=node_id,
args=args,
streaming=True
)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except ValueError as e:
raise e
except Exception as e:
logging.exception("internal server error.")
raise InternalServerError()
class DraftWorkflowRunApi(Resource):
@setup_required
@ -326,6 +391,8 @@ api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-
api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'

View File

@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.app_fields import related_app_list
@ -476,20 +477,22 @@ class DatasetRetrievalSettingApi(Resource):
@account_initialization_required
def get(self):
vector_type = current_app.config['VECTOR_STORE']
if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}:
return {
'retrieval_method': [
'semantic_search'
]
}
elif vector_type in {"qdrant", "weaviate"}:
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
]
}
else:
raise ValueError("Unsupported vector db type.")
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
return {
'retrieval_method': [
'semantic_search'
]
}
case VectorType.QDRANT | VectorType.WEAVIATE:
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
class DatasetRetrievalSettingMockApi(Resource):
@ -497,20 +500,22 @@ class DatasetRetrievalSettingMockApi(Resource):
@login_required
@account_initialization_required
def get(self, vector_type):
if vector_type in {'milvus', 'relyt', 'pgvector'}:
return {
'retrieval_method': [
'semantic_search'
]
}
elif vector_type in {'qdrant', 'weaviate'}:
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
]
}
else:
raise ValueError("Unsupported vector db type.")
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
return {
'retrieval_method': [
'semantic_search'
]
}
case VectorType.QDRANT | VectorType.WEAVIATE:
return {
'retrieval_method': [
'semantic_search', 'full_text_search', 'hybrid_search'
]
}
case _:
raise ValueError(f"Unsupported vector db type {vector_type}.")
class DatasetErrorDocs(Resource):
@setup_required

View File

@ -1,10 +1,12 @@
import logging
from argparse import ArgumentTypeError
from datetime import datetime, timezone
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc
from transformers.hf_argparser import string_to_bool
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -141,7 +143,11 @@ class DatasetDocumentListApi(Resource):
limit = request.args.get('limit', default=20, type=int)
search = request.args.get('keyword', default=None, type=str)
sort = request.args.get('sort', default='-created_at', type=str)
fetch = request.args.get('fetch', default=False, type=bool)
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
fetch = string_to_bool(request.args.get('fetch', default='false'))
except (ArgumentTypeError, ValueError, Exception) as e:
fetch = False
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound('Dataset not found.')
@ -924,6 +930,28 @@ class DocumentRetryApi(DocumentResource):
return {'result': 'success'}, 204
class DocumentRenameApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(document_fields)
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
args = parser.parse_args()
try:
document = DocumentService.rename_document(dataset_id, document_id, args['name'])
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError('Cannot delete document during indexing.')
return document
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
'/datasets/<uuid:dataset_id>/documents')
@ -950,3 +978,5 @@ api.add_resource(DocumentStatusApi,
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
api.add_resource(DocumentRenameApi,
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename')

View File

@ -76,7 +76,7 @@ class ChatTextApi(InstalledAppResource):
response = AudioService.transcript_tts(
app_model=app_model,
text=request.form['text'],
voice=request.form.get('voice'),
voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
return {'data': response.data.decode('latin1')}

View File

@ -1,14 +1,19 @@
from flask_login import current_user
from flask_restful import Resource
from libs.login import login_required
from services.feature_service import FeatureService
from . import api
from .wraps import cloud_utm_record
from .setup import setup_required
from .wraps import account_initialization_required, cloud_utm_record
class FeatureApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_utm_record
def get(self):
return FeatureService.get_features(current_user.current_tenant_id).dict()

View File

@ -17,13 +17,19 @@ class VersionApi(Resource):
args = parser.parse_args()
check_update_url = current_app.config['CHECK_UPDATE_URL']
if not check_update_url:
return {
'version': '0.0.0',
'release_date': '',
'release_notes': '',
'can_auto_update': False
result = {
'version': current_app.config['CURRENT_VERSION'],
'release_date': '',
'release_notes': '',
'can_auto_update': False,
'features': {
'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'],
'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED']
}
}
if not check_update_url:
return result
try:
response = requests.get(check_update_url, {
@ -31,20 +37,15 @@ class VersionApi(Resource):
})
except Exception as error:
logging.warning("Check update version error: {}.".format(str(error)))
return {
'version': args.get('current_version'),
'release_date': '',
'release_notes': '',
'can_auto_update': False
}
result['version'] = args.get('current_version')
return result
content = json.loads(response.content)
return {
'version': content['version'],
'release_date': content['releaseDate'],
'release_notes': content['releaseNotes'],
'can_auto_update': content['canAutoUpdate']
}
result['version'] = content['version']
result['release_date'] = content['releaseDate']
result['release_notes'] = content['releaseNotes']
result['can_auto_update'] = content['canAutoUpdate']
return result
api.add_resource(VersionApi, '/version')

View File

@ -0,0 +1,106 @@
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required
from models.account import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
class LoadBalancingCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
# validate model load balancing credentials
model_load_balancing_service = ModelLoadBalancingService()
result = True
error = None
try:
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
class LoadBalancingConfigCredentialsValidateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args()
# validate model load balancing config credentials
model_load_balancing_service = ModelLoadBalancingService()
result = True
error = None
try:
model_load_balancing_service.validate_load_balancing_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials'],
config_id=config_id,
)
except CredentialsValidateFailedError as ex:
result = False
error = str(ex)
response = {'result': 'success' if result else 'error'}
if not result:
response['error'] = error
return response
# Load Balancing Config
api.add_resource(LoadBalancingCredentialsValidateApi,
'/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate')
api.add_resource(LoadBalancingConfigCredentialsValidateApi,
'/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate')

View File

@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
@ -104,21 +105,56 @@ class ModelProviderModelApi(Resource):
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json')
parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json')
parser.add_argument('config_from', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_load_balancing_service = ModelLoadBalancingService()
try:
model_provider_service.save_model_credentials(
if ('load_balancing' in args and args['load_balancing'] and
'enabled' in args['load_balancing'] and args['load_balancing']['enabled']):
if 'configs' not in args['load_balancing']:
raise ValueError('invalid load balancing configs')
# save load balancing configs
model_load_balancing_service.update_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
configs=args['load_balancing']['configs']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
# enable load balancing
model_load_balancing_service.enable_model_load_balancing(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
else:
# disable load balancing
model_load_balancing_service.disable_model_load_balancing(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
if args.get('config_from', '') != 'predefined-model':
model_provider_service = ModelProviderService()
try:
model_provider_service.save_model_credentials(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type'],
credentials=args['credentials']
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {'result': 'success'}, 200
@ -170,11 +206,73 @@ class ModelProviderModelCredentialApi(Resource):
model=args['model']
)
model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
return {
"credentials": credentials
"credentials": credentials,
"load_balancing": {
"enabled": is_load_balancing_enabled,
"configs": load_balancing_configs
}
}
class ModelProviderModelEnableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.enable_model(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
return {'result': 'success'}
class ModelProviderModelDisableApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, provider: str):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('model', type=str, required=True, nullable=False, location='json')
parser.add_argument('model_type', type=str, required=True, nullable=False,
choices=[mt.value for mt in ModelType], location='json')
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.disable_model(
tenant_id=tenant_id,
provider=provider,
model=args['model'],
model_type=args['model_type']
)
return {'result': 'success'}
class ModelProviderModelValidateApi(Resource):
@setup_required
@ -259,6 +357,10 @@ class ModelProviderAvailableModelApi(Resource):
api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable',
endpoint='model-provider-model-enable')
api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable',
endpoint='model-provider-model-disable')
api.add_resource(ModelProviderModelCredentialApi,
'/workspaces/current/model-providers/<string:provider>/models/credentials')
api.add_resource(ModelProviderModelValidateApi,

View File

@ -9,8 +9,13 @@ from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.tools_manage_service import ToolManageService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
class ToolProviderListApi(Resource):
@ -21,7 +26,11 @@ class ToolProviderListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return ToolManageService.list_tool_providers(user_id, tenant_id)
req = reqparse.RequestParser()
req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args')
args = req.parse_args()
return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None))
class ToolBuiltinProviderListToolsApi(Resource):
@setup_required
@ -31,7 +40,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder(ToolManageService.list_builtin_tool_provider_tools(
return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools(
user_id,
tenant_id,
provider,
@ -48,7 +57,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return ToolManageService.delete_builtin_tool_provider(
return BuiltinToolManageService.delete_builtin_tool_provider(
user_id,
tenant_id,
provider,
@ -70,7 +79,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
args = parser.parse_args()
return ToolManageService.update_builtin_tool_provider(
return BuiltinToolManageService.update_builtin_tool_provider(
user_id,
tenant_id,
provider,
@ -85,7 +94,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return ToolManageService.get_builtin_tool_provider_credentials(
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
user_id,
tenant_id,
provider,
@ -94,7 +103,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
class ToolBuiltinProviderIconApi(Resource):
@setup_required
def get(self, provider):
icon_bytes, mimetype = ToolManageService.get_builtin_tool_provider_icon(provider)
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE'))
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
@ -116,11 +125,12 @@ class ToolApiProviderAddApi(Resource):
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json')
parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[])
parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json')
args = parser.parse_args()
return ToolManageService.create_api_tool_provider(
return ApiToolManageService.create_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@ -130,6 +140,7 @@ class ToolApiProviderAddApi(Resource):
args['schema'],
args.get('privacy_policy', ''),
args.get('custom_disclaimer', ''),
args.get('labels', []),
)
class ToolApiProviderGetRemoteSchemaApi(Resource):
@ -143,7 +154,7 @@ class ToolApiProviderGetRemoteSchemaApi(Resource):
args = parser.parse_args()
return ToolManageService.get_api_tool_provider_remote_schema(
return ApiToolManageService.get_api_tool_provider_remote_schema(
current_user.id,
current_user.current_tenant_id,
args['url'],
@ -163,7 +174,7 @@ class ToolApiProviderListToolsApi(Resource):
args = parser.parse_args()
return jsonable_encoder(ToolManageService.list_api_tool_provider_tools(
return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools(
user_id,
tenant_id,
args['provider'],
@ -188,11 +199,12 @@ class ToolApiProviderUpdateApi(Resource):
parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json')
parser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json')
parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json')
args = parser.parse_args()
return ToolManageService.update_api_tool_provider(
return ApiToolManageService.update_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@ -203,6 +215,7 @@ class ToolApiProviderUpdateApi(Resource):
args['schema'],
args['privacy_policy'],
args['custom_disclaimer'],
args.get('labels', []),
)
class ToolApiProviderDeleteApi(Resource):
@ -222,7 +235,7 @@ class ToolApiProviderDeleteApi(Resource):
args = parser.parse_args()
return ToolManageService.delete_api_tool_provider(
return ApiToolManageService.delete_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@ -242,7 +255,7 @@ class ToolApiProviderGetApi(Resource):
args = parser.parse_args()
return ToolManageService.get_api_tool_provider(
return ApiToolManageService.get_api_tool_provider(
user_id,
tenant_id,
args['provider'],
@ -253,7 +266,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
return ToolManageService.list_builtin_provider_credentials_schema(provider)
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider)
class ToolApiProviderSchemaApi(Resource):
@setup_required
@ -266,7 +279,7 @@ class ToolApiProviderSchemaApi(Resource):
args = parser.parse_args()
return ToolManageService.parser_api_schema(
return ApiToolManageService.parser_api_schema(
schema=args['schema'],
)
@ -286,7 +299,7 @@ class ToolApiProviderPreviousTestApi(Resource):
args = parser.parse_args()
return ToolManageService.test_api_tool_preview(
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
args['provider_name'] if args['provider_name'] else '',
args['tool_name'],
@ -296,6 +309,153 @@ class ToolApiProviderPreviousTestApi(Resource):
args['schema'],
)
class ToolWorkflowProviderCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json')
reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
args = reqparser.parse_args()
return WorkflowToolManageService.create_workflow_tool(
user_id,
tenant_id,
args['workflow_app_id'],
args['name'],
args['label'],
args['icon'],
args['description'],
args['parameters'],
args['privacy_policy'],
args.get('labels', []),
)
class ToolWorkflowProviderUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json')
reqparser.add_argument('label', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('description', type=str, required=True, nullable=False, location='json')
reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json')
reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json')
reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='')
reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json')
args = reqparser.parse_args()
if not args['workflow_tool_id']:
raise ValueError('incorrect workflow_tool_id')
return WorkflowToolManageService.update_workflow_tool(
user_id,
tenant_id,
args['workflow_tool_id'],
args['name'],
args['label'],
args['icon'],
args['description'],
args['parameters'],
args['privacy_policy'],
args.get('labels', []),
)
class ToolWorkflowProviderDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id
tenant_id = current_user.current_tenant_id
reqparser = reqparse.RequestParser()
reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json')
args = reqparser.parse_args()
return WorkflowToolManageService.delete_workflow_tool(
user_id,
tenant_id,
args['workflow_tool_id'],
)
class ToolWorkflowProviderGetApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args')
parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args')
args = parser.parse_args()
if args.get('workflow_tool_id'):
tool = WorkflowToolManageService.get_workflow_tool_by_tool_id(
user_id,
tenant_id,
args['workflow_tool_id'],
)
elif args.get('workflow_app_id'):
tool = WorkflowToolManageService.get_workflow_tool_by_app_id(
user_id,
tenant_id,
args['workflow_app_id'],
)
else:
raise ValueError('incorrect workflow_tool_id or workflow_app_id')
return jsonable_encoder(tool)
class ToolWorkflowProviderListToolApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args')
args = parser.parse_args()
return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools(
user_id,
tenant_id,
args['workflow_tool_id'],
))
class ToolBuiltinListApi(Resource):
@setup_required
@login_required
@ -304,7 +464,7 @@ class ToolBuiltinListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_builtin_tools(
return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)])
@ -317,18 +477,43 @@ class ToolApiListApi(Resource):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder([provider.to_dict() for provider in ToolManageService.list_api_tools(
return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)])
class ToolWorkflowListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools(
user_id,
tenant_id,
)])
class ToolLabelsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
# tool provider
api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin/<provider>/tools')
api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin/<provider>/delete')
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
# api tool provider
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
@ -338,5 +523,15 @@ api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/g
api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema')
api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre')
# workflow tool provider
api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create')
api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update')
api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete')
api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get')
api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools')
api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin')
api.add_resource(ToolApiListApi, '/workspaces/current/tools/api')
api.add_resource(ToolApiListApi, '/workspaces/current/tools/api')
api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow')
api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels')

View File

@ -97,7 +97,7 @@ class MessageListApi(Resource):
class MessageFeedbackApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
@ -114,7 +114,7 @@ class MessageFeedbackApi(Resource):
class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)

View File

@ -1,5 +1,6 @@
from flask import request
from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound
import services.dataset_service
from controllers.service_api import api
@ -19,10 +20,12 @@ def _validate_name(name):
return name
class DatasetApi(DatasetApiResource):
"""Resource for get datasets."""
class DatasetListApi(DatasetApiResource):
"""Resource for datasets."""
def get(self, tenant_id):
"""Resource for getting datasets."""
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
provider = request.args.get('provider', default="vendor")
@ -65,9 +68,9 @@ class DatasetApi(DatasetApiResource):
}
return response, 200
"""Resource for datasets."""
def post(self, tenant_id):
"""Resource for creating datasets."""
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='type is required. Name must be between 1 to 40 characters.',
@ -89,6 +92,31 @@ class DatasetApi(DatasetApiResource):
return marshal(dataset, dataset_detail_fields), 200
class DatasetApi(DatasetApiResource):
"""Resource for dataset."""
api.add_resource(DatasetApi, '/datasets')
def delete(self, _, dataset_id):
"""
Deletes a dataset given its ID.
Args:
dataset_id (UUID): The ID of the dataset to be deleted.
Returns:
dict: A dictionary with a key 'result' and a value 'success'
if the dataset was successfully deleted. Omitted in HTTP response.
int: HTTP status code 204 indicating that the operation was successful.
Raises:
NotFound: If the dataset with the given ID does not exist.
"""
dataset_id_str = str(dataset_id)
if DatasetService.delete_dataset(dataset_id_str, current_user):
return {'result': 'success'}, 204
else:
raise NotFound("Dataset not found.")
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')

View File

@ -8,7 +8,7 @@ from flask import current_app, request
from flask_login import user_logged_in
from flask_restful import Resource
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db
from libs.login import _get_user
@ -39,17 +39,17 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
if not app_model:
raise NotFound()
raise Forbidden("The app no longer exists.")
if app_model.status != 'normal':
raise NotFound()
raise Forbidden("The app's status is abnormal.")
if not app_model.enable_api:
raise NotFound()
raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
if tenant.status == TenantStatus.ARCHIVE:
raise NotFound()
raise Forbidden("The workspace's status is archived.")
kwargs['app_model'] = app_model

View File

@ -74,7 +74,7 @@ class TextApi(WebApiResource):
app_model=app_model,
text=request.form['text'],
end_user=end_user.external_user_id,
voice=request.form.get('voice'),
voice=request.form['voice'] if request.form.get('voice') else None,
streaming=False
)

View File

@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import (
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
@ -128,6 +129,8 @@ class BaseAgentRunner(AppRunner):
self.files = application_generate_entity.files
else:
self.files = []
self.query = None
self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-> AgentChatAppGenerateEntity:
@ -165,6 +168,7 @@ class BaseAgentRunner(AppRunner):
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from
)
tool_entity.load_variables(self.variables_pool)
@ -183,21 +187,11 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string'
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
parameter_type = 'string'
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
parameter_type = 'boolean'
elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
parameter_type = 'number'
elif parameter.type == ToolParameter.ToolParameterType.SELECT:
for option in parameter.options:
enum.append(option.value)
parameter_type = 'string'
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
@ -278,20 +272,10 @@ class BaseAgentRunner(AppRunner):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string'
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
parameter_type = 'string'
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
parameter_type = 'boolean'
elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
parameter_type = 'number'
elif parameter.type == ToolParameter.ToolParameterType.SELECT:
for option in parameter.options:
enum.append(option.value)
parameter_type = 'string'
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
@ -463,7 +447,7 @@ class BaseAgentRunner(AppRunner):
for message in messages:
if message.id == self.message.id:
continue
result.append(self.organize_agent_user_prompt(message))
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:

View File

@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
@ -373,7 +374,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message
def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
organize historic prompt messages
"""
@ -381,6 +382,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpad: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit = None
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=self.history_prompt_messages,
memory=self.memory
).get_prompt()
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
current_scratchpad = AgentScratchpadUnit(

View File

@ -32,9 +32,6 @@ class CotChatAgentRunner(CotAgentRunner):
# organize system prompt
system_message = self._organize_system_prompt()
# organize historic prompt messages
historic_messages = self._historic_prompt_messages
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
if not agent_scratchpad:
@ -57,6 +54,13 @@ class CotChatAgentRunner(CotAgentRunner):
query_messages = UserPromptMessage(content=self._query)
if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([
system_message,
query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
])
messages = [
system_message,
*historic_messages,
@ -65,6 +69,8 @@ class CotChatAgentRunner(CotAgentRunner):
UserPromptMessage(content='continue')
]
else:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
messages = [system_message, *historic_messages, query_messages]
# join all messages

View File

@ -19,11 +19,11 @@ class CotCompletionAgentRunner(CotAgentRunner):
return system_prompt
def _organize_historic_prompt(self) -> str:
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
"""
Organize historic prompt
"""
historic_prompt_messages = self._historic_prompt_messages
historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
historic_prompt = ""
for message in historic_prompt_messages:

View File

@ -8,7 +8,7 @@ class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api"]
provider_type: Literal["builtin", "api", "workflow"]
provider_id: str
tool_name: str
tool_parameters: dict[str, Any] = {}

View File

@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
@ -24,21 +25,18 @@ from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self,
message: Message, query: str, **kwargs: Any
) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
prompt_template = app_config.prompt_template.simple_prompt_template or ''
prompt_messages = self.history_prompt_messages
prompt_messages = self._init_system_message(prompt_template, prompt_messages)
prompt_messages = self._organize_user_query(query, prompt_messages)
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
@ -81,6 +79,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
@ -203,7 +202,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
else:
assistant_message.content = response
prompt_messages.append(assistant_message)
self._current_thoughts.append(assistant_message)
# save thought
self.save_agent_thought(
@ -265,12 +264,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
}
tool_responses.append(tool_response)
prompt_messages = self._organize_assistant_message(
tool_call_id=tool_call_id,
tool_call_name=tool_call_name,
tool_response=tool_response['tool_response'],
prompt_messages=prompt_messages,
)
if tool_response['tool_response'] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=tool_response['tool_response'],
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
if len(tool_responses) > 0:
# save agent thought
@ -300,8 +301,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
iteration_step += 1
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@ -393,24 +392,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return prompt_messages
def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize assistant message
"""
prompt_messages = deepcopy(prompt_messages)
if tool_response is not None:
prompt_messages.append(
ToolPromptMessage(
content=tool_response,
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
@ -428,4 +409,26 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for content in prompt_message.content
])
return prompt_messages
return prompt_messages
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory
).get_prompt()
prompt_messages = [
*self.history_prompt_messages,
*query_prompt_messages,
*self._current_thoughts
]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages

View File

@ -1,7 +1,7 @@
from typing import Optional
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
class AgentConfigManager:

View File

@ -239,4 +239,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
"""
Workflow UI Based App Config Entity.
"""
workflow_id: str
workflow_id: str

View File

@ -98,6 +98,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
extras=extras
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream
)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
if not node_id:
raise ValueError('node_id is required')
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
inputs={},
query='',
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
)
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream
)
def _generate(self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation = None,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
is_first_conversation = False
if not conversation:
is_first_conversation = True
@ -167,18 +251,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"""
with flask_app.app_context():
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = AdvancedChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = AdvancedChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:

View File

@ -102,6 +102,7 @@ class AdvancedChatAppRunner(AppRunner):
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.QUERY: query,
@ -109,6 +110,35 @@ class AdvancedChatAppRunner(AppRunner):
SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth
)
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
)

View File

@ -12,6 +12,9 @@ from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
@ -64,6 +67,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow: Workflow
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariable, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
@ -103,6 +107,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
usage=LLMUsage.empty_usage()
)
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
@ -204,6 +209,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
# reset current route position to 0
self._task_state.current_stream_generate_state.current_route_position = 0
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
@ -225,6 +232,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(event)
if workflow_run:
@ -263,10 +286,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
# elif isinstance(event, QueueMessageFileEvent):
# response = self._message_file_to_stream_response(event)
# if response:
# yield response
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
@ -342,7 +361,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
id=self._message.id,
**extras
)
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
@ -372,7 +391,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
return stream_generate_routes
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
@ -391,6 +410,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
ingoing_edges.append(edge)
if not ingoing_edges:
# check if it's the first node in the iteration
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
if not target_node:
return []
node_iteration_id = target_node.get('data', {}).get('iteration_id')
# get iteration start node id
for node in nodes:
if node.get('id') == node_iteration_id:
if node.get('data', {}).get('start_node_id') == target_node_id:
return [target_node_id]
return []
start_node_ids = []
@ -401,14 +432,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value:
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
@ -417,7 +457,27 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
@ -425,7 +485,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
if self._task_state.current_stream_generate_state:
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
self._task_state.current_stream_generate_state.current_route_position:
]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
@ -458,13 +519,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
value = None
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
@ -476,6 +538,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if route_chunk_node_id == 'sys':
# system variable
value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
continue
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
iterator = iteration_state.inputs
if not iterator:
continue
iterator_selector = iterator.get('iterator_selector', [])
if value_selector[1] == 'index':
value = iteration_state.current_index
elif value_selector[1] == 'item':
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
iterator_selector) else None
else:
# check chunk node id is before current node id or equal to current node id
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
@ -505,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
value = value.get(key)
if value:
if value is not None:
text = ''
if isinstance(value, str | int | float):
text = str(value)

View File

@ -1,8 +1,11 @@
from typing import Optional
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager._publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager._publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event

View File

@ -115,7 +115,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras=extras
extras=extras,
call_depth=0
)
# init generate records

View File

@ -1,6 +1,6 @@
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -16,11 +16,11 @@ from core.app.features.hosting_moderation.hosting_moderation import HostingModer
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.file_obj import FileVar
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@ -45,8 +45,11 @@ class AppRunner:
:param query: query
:return:
"""
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
@ -73,9 +76,7 @@ class AppRunner:
query=query
)
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
@ -89,8 +90,10 @@ class AppRunner:
def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
@ -107,9 +110,7 @@ class AppRunner:
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)

View File

@ -34,7 +34,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True) \
stream: bool = True,
call_depth: int = 0) \
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@ -75,9 +76,38 @@ class WorkflowAppGenerator(BaseAppGenerator):
files=file_objs,
user_id=user.id,
stream=stream,
invoke_from=invoke_from
invoke_from=invoke_from,
call_depth=call_depth
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
call_depth=call_depth
)
def _generate(self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0) \
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
@ -109,6 +139,64 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from=invoke_from
)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param user: account or end user
:param args: request args
:param invoke_from: invoke from source
:param stream: is stream
"""
if not node_id:
raise ValueError('node_id is required')
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
inputs={},
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
)
)
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
stream=stream
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager) -> None:
@ -123,10 +211,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
try:
# workflow app
runner = WorkflowAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager
)
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:

View File

@ -73,11 +73,44 @@ class WorkflowAppRunner:
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
user_inputs=inputs,
system_inputs={
SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id
},
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth
)
def single_iteration_run(self, app_id: str, workflow_id: str,
queue_manager: AppQueueManager,
inputs: dict, node_id: str, user_id: str) -> None:
"""
Single iteration run
"""
app_record: App = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError("App not found")
if not app_record.workflow_id:
raise ValueError("Workflow not initialized")
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
workflow_callbacks = [WorkflowEventTriggerCallback(
queue_manager=queue_manager,
workflow=workflow
)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow,
node_id=node_id,
user_id=user_id,
user_inputs=inputs,
callbacks=workflow_callbacks
)

View File

@ -9,6 +9,9 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import (
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
@ -58,6 +61,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariable, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
@ -85,8 +89,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
SystemVariable.USER_ID: user_id
}
self._task_state = WorkflowTaskState()
self._task_state = WorkflowTaskState(
iteration_nested_node_ids=[]
)
self._stream_generate_nodes = self._get_stream_generate_nodes()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -191,6 +198,22 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(event)
@ -331,13 +354,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value:
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
@ -411,3 +441,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
return False
return True
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

View File

@ -1,8 +1,11 @@
from typing import Optional
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
@ -130,6 +133,66 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager.publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager.publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event

View File

@ -102,6 +102,39 @@ class WorkflowLoggingCallback(BaseWorkflowCallback):
self.print_text(text, color="pink", end="")
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self.print_text("\n[on_workflow_iteration_started]", color='blue')
self.print_text(f"Node ID: {node_id}", color='blue')
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[dict]) -> None:
"""
Publish iteration next
"""
self.print_text("\n[on_workflow_iteration_next]", color='blue')
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event

View File

@ -80,6 +80,9 @@ class AppGenerateEntity(BaseModel):
stream: bool
invoke_from: InvokeFrom
# invoke call depth
call_depth: int = 0
# extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {}
@ -126,6 +129,14 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
conversation_id: Optional[str] = None
query: Optional[str] = None
class SingleIterationRunEntity(BaseModel):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
@ -133,3 +144,12 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
# app config
app_config: WorkflowUIBasedAppConfig
class SingleIterationRunEntity(BaseModel):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None

View File

@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from pydantic import BaseModel, validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -21,6 +21,9 @@ class QueueEvent(Enum):
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
NODE_STARTED = "node_started"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
@ -47,6 +50,55 @@ class QueueLLMChunkEvent(AppQueueEvent):
event = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
class QueueIterationStartEvent(AppQueueEvent):
"""
QueueIterationStartEvent entity
"""
event = QueueEvent.ITERATION_START
node_id: str
node_type: NodeType
node_data: BaseNodeData
node_run_index: int
inputs: dict = None
predecessor_node_id: Optional[str] = None
metadata: Optional[dict] = None
class QueueIterationNextEvent(AppQueueEvent):
"""
QueueIterationNextEvent entity
"""
event = QueueEvent.ITERATION_NEXT
index: int
node_id: str
node_type: NodeType
node_run_index: int
output: Optional[Any] # output for the current iteration
@validator('output', pre=True, always=True)
def set_output(cls, v):
"""
Set output
"""
if v is None:
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError('output must be a valid type')
class QueueIterationCompletedEvent(AppQueueEvent):
"""
QueueIterationCompletedEvent entity
"""
event = QueueEvent.ITERATION_COMPLETED
node_id: str
node_type: NodeType
node_run_index: int
outputs: dict
class QueueTextChunkEvent(AppQueueEvent):
"""

View File

@ -1,12 +1,14 @@
from enum import Enum
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
from models.workflow import WorkflowNodeExecutionStatus
class WorkflowStreamGenerateNodes(BaseModel):
@ -65,6 +67,7 @@ class WorkflowTaskState(TaskState):
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
iteration_nested_node_ids: list[str] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
@ -91,6 +94,9 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
@ -319,6 +325,74 @@ class NodeFinishStreamResponse(StreamResponse):
}
}
class IterationNodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
created_at: int
extras: dict = {}
metadata: dict = {}
inputs: dict = {}
event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
data: Data
class IterationNodeNextStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
index: int
created_at: int
pre_iteration_output: Optional[Any]
extras: dict = {}
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
data: Data
class IterationNodeCompletedStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
outputs: Optional[dict]
created_at: int
extras: dict = None
inputs: dict = None
status: WorkflowNodeExecutionStatus
error: Optional[str]
elapsed_time: float
total_tokens: int
finished_at: int
steps: int
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
data: Data
class TextChunkStreamResponse(StreamResponse):
"""
@ -454,3 +528,23 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
class WorkflowIterationState(BaseModel):
"""
WorkflowIterationState entity
"""
class Data(BaseModel):
"""
Data entity
"""
parent_iteration_id: Optional[str] = None
iteration_id: str
current_index: int
iteration_steps_boundary: list[int] = None
node_execution_id: str
started_at: float
inputs: dict = None
total_tokens: int = 0
node_data: BaseNodeData
current_iterations: dict[str, Data] = None

View File

@ -37,6 +37,7 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@ -317,29 +318,30 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
"""
model_config = self._model_config
model = model_config.model
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
)
# calculate num tokens
prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_type_instance.get_num_tokens(
model,
model_config.credentials,
prompt_tokens = model_instance.get_llm_num_tokens(
self._task_state.llm_result.prompt_messages
)
completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_type_instance.get_num_tokens(
model,
model_config.credentials,
completion_tokens = model_instance.get_llm_num_tokens(
[self._task_state.llm_result.message]
)
credentials = model_config.credentials
# transform usage
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model,
credentials,

View File

@ -1,9 +1,9 @@
import json
import time
from datetime import datetime, timezone
from typing import Any, Optional, Union, cast
from typing import Optional, Union, cast
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
@ -13,18 +13,17 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
NodeExecutionInfo,
NodeFinishStreamResponse,
NodeStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
@ -42,13 +41,7 @@ from models.workflow import (
)
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]
class WorkflowCycleManage(WorkflowIterationCycleManage):
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
@ -237,6 +230,7 @@ class WorkflowCycleManage:
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
@ -255,6 +249,8 @@ class WorkflowCycleManage:
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
db.session.commit()
db.session.refresh(workflow_node_execution)
@ -444,6 +440,23 @@ class WorkflowCycleManage:
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
if self._iteration_state and self._iteration_state.current_iterations:
if not execution_metadata:
execution_metadata = {}
current_iteration_data = None
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if data.parent_iteration_id == None:
current_iteration_data = data
break
if current_iteration_data:
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
@ -451,12 +464,18 @@ class WorkflowCycleManage:
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
execution_metadata=event.execution_metadata
execution_metadata=execution_metadata
)
if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
if self._iteration_state:
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
@ -469,7 +488,8 @@ class WorkflowCycleManage:
error=event.error,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs
outputs=event.outputs,
execution_metadata=execution_metadata
)
db.session.close()

View File

@ -0,0 +1,16 @@
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.entities.node_entities import SystemVariable
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
class WorkflowCycleStateManager:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]

View File

@ -0,0 +1,281 @@
import json
import time
from collections.abc import Generator
from typing import Optional, Union
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
)
from core.app.entities.task_entities import (
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeExecutionInfo,
WorkflowIterationState,
)
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
from core.workflow.entities.node_entities import NodeType
from extensions.ext_database import db
from models.workflow import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
)
class WorkflowIterationCycleManage(WorkflowCycleStateManager):
_iteration_state: WorkflowIterationState = None
def _init_iteration_state(self) -> WorkflowIterationState:
if not self._iteration_state:
self._iteration_state = WorkflowIterationState(
current_iterations={}
)
def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
"""
Handle iteration to stream response
:param task_id: task id
:param event: iteration event
:return:
"""
if isinstance(event, QueueIterationStartEvent):
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs,
metadata=event.metadata
)
)
elif isinstance(event, QueueIterationNextEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={}
)
)
elif isinstance(event, QueueIterationCompletedEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)
def _init_iteration_execution_from_workflow_run(self,
workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
inputs=json.dumps(inputs) if inputs else None,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
execution_metadata=json.dumps({
'started_run_index': node_run_index + 1,
'current_index': 0,
'steps_boundary': [],
})
)
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
if isinstance(event, QueueIterationStartEvent):
return self._handle_iteration_started(event)
elif isinstance(event, QueueIterationNextEvent):
return self._handle_iteration_next(event)
elif isinstance(event, QueueIterationCompletedEvent):
return self._handle_iteration_completed(event)
def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
self._init_iteration_state()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=NodeType.ITERATION,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id
)
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
parent_iteration_id=None,
iteration_id=event.node_id,
current_index=0,
iteration_steps_boundary=[],
node_execution_id=workflow_node_execution.id,
started_at=time.perf_counter(),
inputs=event.inputs,
total_tokens=0,
node_data=event.node_data
)
db.session.close()
return workflow_node_execution
def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
current_iteration.current_index = event.index
current_iteration.iteration_steps_boundary.append(event.node_run_index)
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['current_index'] = event.index
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
db.session.close()
def _handle_iteration_completed(self, event: QueueIterationCompletedEvent) -> WorkflowNodeExecution:
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.outputs = json.dumps(event.outputs) if event.outputs else None
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
# remove current iteration
self._iteration_state.current_iterations.pop(event.node_id, None)
# set latest node execution info
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.latest_node_execution_info = latest_node_execution_info
db.session.close()
def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
"""
Handle iteration exception
"""
if not self._iteration_state or not self._iteration_state.current_iterations:
return
for node_id, current_iteration in self._iteration_state.current_iterations.items():
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
db.session.commit()
db.session.close()
yield IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=node_id,
node_id=node_id,
node_type=NodeType.ITERATION.value,
title=current_iteration.node_data.title,
outputs={},
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.FAILED,
error=error,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)

View File

@ -16,6 +16,7 @@ class ModelStatus(Enum):
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
DISABLED = "disabled"
class SimpleModelProviderEntity(BaseModel):
@ -43,12 +44,19 @@ class SimpleModelProviderEntity(BaseModel):
)
class ModelWithProviderEntity(ProviderModel):
class ProviderModelWithStatusEntity(ProviderModel):
"""
Model class for model response.
"""
status: ModelStatus
load_balancing_enabled: bool = False
class ModelWithProviderEntity(ProviderModelWithStatusEntity):
"""
Model with provider entity.
"""
provider: SimpleModelProviderEntity
status: ModelStatus
class DefaultModelProviderEntity(BaseModel):

View File

@ -1,6 +1,7 @@
import datetime
import json
import logging
from collections import defaultdict
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional
@ -8,7 +9,12 @@ from typing import Optional
from pydantic import BaseModel
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
from core.entities.provider_entities import (
CustomConfiguration,
ModelSettings,
SystemConfiguration,
SystemConfigurationStatus,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
@ -22,7 +28,14 @@ from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from extensions.ext_database import db
from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
from models.provider import (
LoadBalancingModelConfig,
Provider,
ProviderModel,
ProviderModelSetting,
ProviderType,
TenantPreferredModelProvider,
)
logger = logging.getLogger(__name__)
@ -39,6 +52,7 @@ class ProviderConfiguration(BaseModel):
using_provider_type: ProviderType
system_configuration: SystemConfiguration
custom_configuration: CustomConfiguration
model_settings: list[ModelSettings]
def __init__(self, **data):
super().__init__(**data)
@ -62,6 +76,14 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
if self.model_settings:
# check if model is disabled by admin
for model_setting in self.model_settings:
if (model_setting.model_type == model_type
and model_setting.model == model):
if not model_setting.enabled:
raise ValueError(f'Model {model} is disabled.')
if self.using_provider_type == ProviderType.SYSTEM:
restrict_models = []
for quota_configuration in self.system_configuration.quota_configurations:
@ -80,15 +102,17 @@ class ProviderConfiguration(BaseModel):
return copy_credentials
else:
credentials = None
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
return model_configuration.credentials
credentials = model_configuration.credentials
break
if self.custom_configuration.provider:
return self.custom_configuration.provider.credentials
else:
return None
credentials = self.custom_configuration.provider.credentials
return credentials
def get_system_configuration_status(self) -> SystemConfigurationStatus:
"""
@ -130,7 +154,7 @@ class ProviderConfiguration(BaseModel):
return credentials
# Obfuscate credentials
return self._obfuscated_credentials(
return self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
@ -151,7 +175,7 @@ class ProviderConfiguration(BaseModel):
).first()
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema else []
)
@ -274,7 +298,7 @@ class ProviderConfiguration(BaseModel):
return credentials
# Obfuscate credentials
return self._obfuscated_credentials(
return self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
@ -302,7 +326,7 @@ class ProviderConfiguration(BaseModel):
).first()
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema else []
)
@ -402,6 +426,160 @@ class ProviderConfiguration(BaseModel):
provider_model_credentials_cache.delete()
def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
"""
Enable model.
:param model_type: model type
:param model: model name
:return:
"""
model_setting = db.session.query(ProviderModelSetting) \
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
if model_setting:
model_setting.enabled = True
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=True
)
db.session.add(model_setting)
db.session.commit()
return model_setting
def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
"""
Disable model.
:param model_type: model type
:param model: model name
:return:
"""
model_setting = db.session.query(ProviderModelSetting) \
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
if model_setting:
model_setting.enabled = False
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=False
)
db.session.add(model_setting)
db.session.commit()
return model_setting
def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
"""
Get provider model setting.
:param model_type: model type
:param model: model name
:return:
"""
return db.session.query(ProviderModelSetting) \
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
"""
Enable model load balancing.
:param model_type: model type
:param model: model name
:return:
"""
load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
.filter(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model
).count()
if load_balancing_config_count <= 1:
raise ValueError('Model load balancing configuration must be more than 1.')
model_setting = db.session.query(ProviderModelSetting) \
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
if model_setting:
model_setting.load_balancing_enabled = True
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=True
)
db.session.add(model_setting)
db.session.commit()
return model_setting
def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
"""
Disable model load balancing.
:param model_type: model type
:param model: model name
:return:
"""
model_setting = db.session.query(ProviderModelSetting) \
.filter(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name == self.provider.provider,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model
).first()
if model_setting:
model_setting.load_balancing_enabled = False
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit()
else:
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=False
)
db.session.add(model_setting)
db.session.commit()
return model_setting
def get_provider_instance(self) -> ModelProvider:
"""
Get provider instance.
@ -453,7 +631,7 @@ class ProviderConfiguration(BaseModel):
db.session.commit()
def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
"""
Extract secret input form variables.
@ -467,7 +645,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables
def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
"""
Obfuscated credentials.
@ -476,7 +654,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
# Get provider credential secret variables
credential_secret_variables = self._extract_secret_variables(
credential_secret_variables = self.extract_secret_variables(
credential_form_schemas
)
@ -522,15 +700,22 @@ class ProviderConfiguration(BaseModel):
else:
model_types = provider_instance.get_provider_schema().supported_model_types
# Group model settings by model type and model
model_setting_map = defaultdict(dict)
for model_setting in self.model_settings:
model_setting_map[model_setting.model_type][model_setting.model] = model_setting
if self.using_provider_type == ProviderType.SYSTEM:
provider_models = self._get_system_provider_models(
model_types=model_types,
provider_instance=provider_instance
provider_instance=provider_instance,
model_setting_map=model_setting_map
)
else:
provider_models = self._get_custom_provider_models(
model_types=model_types,
provider_instance=provider_instance
provider_instance=provider_instance,
model_setting_map=model_setting_map
)
if only_active:
@ -541,18 +726,27 @@ class ProviderConfiguration(BaseModel):
def _get_system_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-> list[ModelWithProviderEntity]:
"""
Get system provider models.
:param model_types: model types
:param provider_instance: provider instance
:param model_setting_map: model setting map
:return:
"""
provider_models = []
for model_type in model_types:
provider_models.extend(
[
for m in provider_instance.models(model_type):
status = ModelStatus.ACTIVE
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
model_setting = model_setting_map[m.model_type][m.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
provider_models.append(
ModelWithProviderEntity(
model=m.model,
label=m.label,
@ -562,11 +756,9 @@ class ProviderConfiguration(BaseModel):
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
status=status
)
for m in provider_instance.models(model_type)
]
)
)
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
@ -586,7 +778,8 @@ class ProviderConfiguration(BaseModel):
break
if should_use_custom_model:
if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
if original_provider_configurate_methods[self.provider.provider] == [
ConfigurateMethod.CUSTOMIZABLE_MODEL]:
# only customizable model
for restrict_model in restrict_models:
copy_credentials = self.system_configuration.credentials.copy()
@ -611,6 +804,13 @@ class ProviderConfiguration(BaseModel):
if custom_model_schema.model_type not in model_types:
continue
status = ModelStatus.ACTIVE
if (custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
@ -621,7 +821,7 @@ class ProviderConfiguration(BaseModel):
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
status=status
)
)
@ -632,16 +832,20 @@ class ProviderConfiguration(BaseModel):
m.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid:
m.status = ModelStatus.QUOTA_EXCEEDED
return provider_models
def _get_custom_provider_models(self,
model_types: list[ModelType],
provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
provider_instance: ModelProvider,
model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
-> list[ModelWithProviderEntity]:
"""
Get custom provider models.
:param model_types: model types
:param provider_instance: provider instance
:param model_setting_map: model setting map
:return:
"""
provider_models = []
@ -656,6 +860,16 @@ class ProviderConfiguration(BaseModel):
models = provider_instance.models(model_type)
for m in models:
status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
load_balancing_enabled = False
if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
model_setting = model_setting_map[m.model_type][m.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
if len(model_setting.load_balancing_configs) > 1:
load_balancing_enabled = True
provider_models.append(
ModelWithProviderEntity(
model=m.model,
@ -666,7 +880,8 @@ class ProviderConfiguration(BaseModel):
model_properties=m.model_properties,
deprecated=m.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
status=status,
load_balancing_enabled=load_balancing_enabled
)
)
@ -690,6 +905,17 @@ class ProviderConfiguration(BaseModel):
if not custom_model_schema:
continue
status = ModelStatus.ACTIVE
load_balancing_enabled = False
if (custom_model_schema.model_type in model_setting_map
and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
if model_setting.enabled is False:
status = ModelStatus.DISABLED
if len(model_setting.load_balancing_configs) > 1:
load_balancing_enabled = True
provider_models.append(
ModelWithProviderEntity(
model=custom_model_schema.model,
@ -700,7 +926,8 @@ class ProviderConfiguration(BaseModel):
model_properties=custom_model_schema.model_properties,
deprecated=custom_model_schema.deprecated,
provider=SimpleModelProviderEntity(self.provider),
status=ModelStatus.ACTIVE
status=status,
load_balancing_enabled=load_balancing_enabled
)
)

View File

@ -72,3 +72,22 @@ class CustomConfiguration(BaseModel):
"""
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []
class ModelLoadBalancingConfiguration(BaseModel):
"""
Class for model load balancing configuration.
"""
id: str
name: str
credentials: dict
class ModelSettings(BaseModel):
"""
Model class for model settings.
"""
model: str
model_type: ModelType
enabled: bool = True
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []

View File

@ -7,7 +7,7 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.utils.position_helper import sort_to_dict_by_position_map
from core.helper.position_helper import sort_to_dict_by_position_map
class ExtensionModule(enum.Enum):

View File

@ -42,6 +42,8 @@ class MessageFileParser:
raise ValueError('Invalid file url')
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
raise ValueError('Missing file upload_file_id')
if file.get('transform_method') == FileTransferMethod.TOOL_FILE.value and not file.get('tool_file_id'):
raise ValueError('Missing file tool_file_id')
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_extra_config)
@ -149,12 +151,21 @@ class MessageFileParser:
"""
if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
if transfer_method != FileTransferMethod.TOOL_FILE:
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
transfer_method=transfer_method,
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=file_extra_config
)
return FileVar(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
transfer_method=transfer_method,
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
url=None,
related_id=file.get('tool_file_id'),
extra_config=file_extra_config
)
else:

View File

@ -77,4 +77,4 @@ class UploadFileParser:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= 300 # expired after 5 minutes
return current_time - int(timestamp) <= current_app.config.get('FILES_ACCESS_TIMEOUT')

View File

@ -12,7 +12,7 @@ from config import get_env
from core.helper.code_executor.entities import CodeDependency
from core.helper.code_executor.javascript.javascript_transformer import NodeJsTemplateTransformer
from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTransformer
from core.helper.code_executor.python3.python3_transformer import PYTHON_STANDARD_PACKAGES, Python3TemplateTransformer
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__)
@ -99,7 +99,9 @@ class CodeExecutor:
except CodeExecutionException as e:
raise e
except Exception as e:
raise CodeExecutionException('Failed to execute code, this is likely a network issue, please check if the sandbox service is running')
raise CodeExecutionException('Failed to execute code, which is likely a network issue,'
' please check if the sandbox service is running.'
f' ( Error: {str(e)} )')
try:
response = response.json()
@ -187,7 +189,8 @@ class CodeExecutor:
response = response.json()
dependencies = response.get('data', {}).get('dependencies', [])
return [
CodeDependency(**dependency) for dependency in dependencies if dependency.get('name') not in PYTHON_STANDARD_PACKAGES
CodeDependency(**dependency) for dependency in dependencies
if dependency.get('name') not in Python3TemplateTransformer.get_standard_packages()
]
except Exception as e:
logger.exception(f'Failed to list dependencies: {e}')

View File

@ -1,58 +1,25 @@
import json
import re
from typing import Optional
from textwrap import dedent
from core.helper.code_executor.entities import CodeDependency
from core.helper.code_executor.template_transformer import TemplateTransformer
NODEJS_RUNNER = """// declare main function here
{{code}}
// execute main function, and return the result
// inputs is a dict, unstructured inputs
output = main({{inputs}})
// convert output to json and print
output = JSON.stringify(output)
result = `<<RESULT>>${output}<<RESULT>>`
console.log(result)
"""
NODEJS_PRELOAD = """"""
class NodeJsTemplateTransformer(TemplateTransformer):
@classmethod
def transform_caller(cls, code: str, inputs: dict,
dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
"""
Transform code to python runner
:param code: code
:param inputs: inputs
:return:
"""
# transform inputs to json string
inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)
# replace code and inputs
runner = NODEJS_RUNNER.replace('{{code}}', code)
runner = runner.replace('{{inputs}}', inputs_str)
return runner, NODEJS_PRELOAD, []
@classmethod
def transform_response(cls, response: str) -> dict:
"""
Transform response to dict
:param response: response
:return:
"""
# extract result
result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
if not result:
raise ValueError('Failed to parse result')
result = result.group(1)
return json.loads(result)
def get_runner_script(cls) -> str:
runner_script = dedent(
f"""
// declare main function
{cls._code_placeholder}
// decode and prepare input object
var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
// execute main function
var output_obj = main(inputs_obj)
// convert output to json and print
var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>${{output_json}}<<RESULT>>`
console.log(result)
""")
return runner_script

View File

@ -1,94 +1,13 @@
import json
import re
from base64 import b64encode
from typing import Optional
from textwrap import dedent
from core.helper.code_executor.entities import CodeDependency
from core.helper.code_executor.python3.python3_transformer import PYTHON_STANDARD_PACKAGES
from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer
from core.helper.code_executor.template_transformer import TemplateTransformer
PYTHON_RUNNER = """
import jinja2
from json import loads
from base64 import b64decode
template = jinja2.Template('''{{code}}''')
def main(**inputs):
return template.render(**inputs)
# execute main function, and return the result
inputs = b64decode('{{inputs}}').decode('utf-8')
output = main(**loads(inputs))
result = f'''<<RESULT>>{output}<<RESULT>>'''
print(result)
"""
JINJA2_PRELOAD_TEMPLATE = """{% set fruits = ['Apple'] %}
{{ 'a' }}
{% for fruit in fruits %}
<li>{{ fruit }}</li>
{% endfor %}
{% if fruits|length > 1 %}
1
{% endif %}
{% for i in range(5) %}
{% if i == 3 %}{{ i }}{% else %}{% endif %}
{% endfor %}
{% for i in range(3) %}
{{ i + 1 }}
{% endfor %}
{% macro say_hello() %}a{{ 'b' }}{% endmacro %}
{{ s }}{{ say_hello() }}"""
JINJA2_PRELOAD = f"""
import jinja2
from base64 import b64decode
def _jinja2_preload_():
# prepare jinja2 environment, load template and render before to avoid sandbox issue
template = jinja2.Template('''{JINJA2_PRELOAD_TEMPLATE}''')
template.render(s='a')
if __name__ == '__main__':
_jinja2_preload_()
"""
class Jinja2TemplateTransformer(TemplateTransformer):
@classmethod
def transform_caller(cls, code: str, inputs: dict,
dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
"""
Transform code to python runner
:param code: code
:param inputs: inputs
:return:
"""
inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
# transform jinja2 template to python code
runner = PYTHON_RUNNER.replace('{{code}}', code)
runner = runner.replace('{{inputs}}', inputs_str)
if not dependencies:
dependencies = []
# add native packages and jinja2
for package in PYTHON_STANDARD_PACKAGES.union(['jinja2']):
dependencies.append(CodeDependency(name=package, version=''))
# deduplicate
dependencies = list({
dep.name: dep for dep in dependencies if dep.name
}.values())
return runner, JINJA2_PRELOAD, dependencies
def get_standard_packages(cls) -> set[str]:
return {'jinja2'} | Python3TemplateTransformer.get_standard_packages()
@classmethod
def transform_response(cls, response: str) -> dict:
@ -97,12 +16,49 @@ class Jinja2TemplateTransformer(TemplateTransformer):
:param response: response
:return:
"""
# extract result
result = re.search(r'<<RESULT>>(.*)<<RESULT>>', response, re.DOTALL)
if not result:
raise ValueError('Failed to parse result')
result = result.group(1)
return {
'result': result
'result': cls.extract_result_str_from_response(response)
}
@classmethod
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
# declare main function
def main(**inputs):
import jinja2
template = jinja2.Template('''{cls._code_placeholder}''')
return template.render(**inputs)
import json
from base64 import b64decode
# decode and prepare input dict
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
# execute main function
output = main(**inputs_obj)
# convert output and print
result = f'''<<RESULT>>{{output}}<<RESULT>>'''
print(result)
""")
return runner_script
@classmethod
def get_preload_script(cls) -> str:
preload_script = dedent("""
import jinja2
from base64 import b64decode
def _jinja2_preload_():
# prepare jinja2 environment, load template and render before to avoid sandbox issue
template = jinja2.Template('{{s}}')
template.render(s='a')
if __name__ == '__main__':
_jinja2_preload_()
""")
return preload_script

View File

@ -1,83 +1,51 @@
import json
import re
from base64 import b64encode
from textwrap import dedent
from typing import Optional
from core.helper.code_executor.entities import CodeDependency
from core.helper.code_executor.template_transformer import TemplateTransformer
PYTHON_RUNNER = dedent("""
# declare main function here
{{code}}
from json import loads, dumps
from base64 import b64decode
# execute main function, and return the result
# inputs is a dict, and it
inputs = b64decode('{{inputs}}').decode('utf-8')
output = main(**json.loads(inputs))
# convert output to json and print
output = dumps(output, indent=4)
result = f'''<<RESULT>>
{output}
<<RESULT>>'''
print(result)
""")
PYTHON_PRELOAD = """"""
PYTHON_STANDARD_PACKAGES = {
'json', 'datetime', 'math', 'random', 're', 'string', 'sys', 'time', 'traceback', 'uuid', 'os', 'base64',
'hashlib', 'hmac', 'binascii', 'collections', 'functools', 'operator', 'itertools', 'uuid',
}
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
def transform_caller(cls, code: str, inputs: dict,
dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
"""
Transform code to python runner
:param code: code
:param inputs: inputs
:return:
"""
# transform inputs to json string
inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
def get_standard_packages(cls) -> set[str]:
return {
'base64',
'binascii',
'collections',
'datetime',
'functools',
'hashlib',
'hmac',
'itertools',
'json',
'math',
'operator',
'os',
'random',
're',
'string',
'sys',
'time',
'traceback',
'uuid',
}
# replace code and inputs
runner = PYTHON_RUNNER.replace('{{code}}', code)
runner = runner.replace('{{inputs}}', inputs_str)
# add standard packages
if dependencies is None:
dependencies = []
for package in PYTHON_STANDARD_PACKAGES:
if package not in dependencies:
dependencies.append(CodeDependency(name=package, version=''))
# deduplicate
dependencies = list({dep.name: dep for dep in dependencies if dep.name}.values())
return runner, PYTHON_PRELOAD, dependencies
@classmethod
def transform_response(cls, response: str) -> dict:
"""
Transform response to dict
:param response: response
:return:
"""
# extract result
result = re.search(r'<<RESULT>>(.*?)<<RESULT>>', response, re.DOTALL)
if not result:
raise ValueError('Failed to parse result')
result = result.group(1)
return json.loads(result)
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
# declare main function
{cls._code_placeholder}
import json
from base64 import b64decode
# decode and prepare input dict
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
# execute main function
output_obj = main(**inputs_obj)
# convert output to json and print
output_json = json.dumps(output_obj, indent=4)
result = f'''<<RESULT>>{{output_json}}<<RESULT>>'''
print(result)
""")
return runner_script

View File

@ -1,13 +1,25 @@
import json
import re
from abc import ABC, abstractmethod
from base64 import b64encode
from typing import Optional
from pydantic import BaseModel
from core.helper.code_executor.entities import CodeDependency
class TemplateTransformer(ABC):
class TemplateTransformer(ABC, BaseModel):
_code_placeholder: str = '{{code}}'
_inputs_placeholder: str = '{{inputs}}'
_result_tag: str = '<<RESULT>>'
@classmethod
@abstractmethod
def transform_caller(cls, code: str, inputs: dict,
def get_standard_packages(cls) -> set[str]:
return set()
@classmethod
def transform_caller(cls, code: str, inputs: dict,
dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
"""
Transform code to python runner
@ -15,14 +27,61 @@ class TemplateTransformer(ABC):
:param inputs: inputs
:return: runner, preload
"""
pass
runner_script = cls.assemble_runner_script(code, inputs)
preload_script = cls.get_preload_script()
packages = dependencies or []
standard_packages = cls.get_standard_packages()
for package in standard_packages:
if package not in packages:
packages.append(CodeDependency(name=package, version=''))
packages = list({dep.name: dep for dep in packages if dep.name}.values())
return runner_script, preload_script, packages
@classmethod
def extract_result_str_from_response(cls, response: str) -> str:
result = re.search(rf'{cls._result_tag}(.*){cls._result_tag}', response, re.DOTALL)
if not result:
raise ValueError('Failed to parse result')
result = result.group(1)
return result
@classmethod
@abstractmethod
def transform_response(cls, response: str) -> dict:
"""
Transform response to dict
:param response: response
:return:
"""
pass
return json.loads(cls.extract_result_str_from_response(response))
@classmethod
@abstractmethod
def get_runner_script(cls) -> str:
"""
Get runner script
"""
pass
@classmethod
def serialize_inputs(cls, inputs: dict) -> str:
inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode()
input_base64_encoded = b64encode(inputs_json_str).decode('utf-8')
return input_base64_encoded
@classmethod
def assemble_runner_script(cls, code: str, inputs: dict) -> str:
# assemble runner script
script = cls.get_runner_script()
script = script.replace(cls._code_placeholder, code)
inputs_str = cls.serialize_inputs(inputs)
script = script.replace(cls._inputs_placeholder, inputs_str)
return script
@classmethod
def get_preload_script(cls) -> str:
"""
Get preload script
"""
return ''

View File

@ -9,6 +9,7 @@ from extensions.ext_redis import redis_client
class ProviderCredentialsCacheType(Enum):
PROVIDER = "provider"
MODEL = "provider_model"
LOAD_BALANCING_MODEL = "load_balancing_provider_model"
class ProviderCredentialsCache:

View File

@ -1,10 +1,9 @@
import logging
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, AnyStr
import yaml
from core.tools.utils.yaml_utils import load_yaml_file
def get_position_map(
@ -17,21 +16,15 @@ def get_position_map(
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
try:
position_file_name = os.path.join(folder_path, file_name)
if not os.path.exists(position_file_name):
return {}
with open(position_file_name, encoding='utf-8') as f:
positions = yaml.safe_load(f)
position_map = {}
for index, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
return position_map
except:
logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
return {}
position_file_name = os.path.join(folder_path, file_name)
positions = load_yaml_file(position_file_name, ignore_error=True)
position_map = {}
index = 0
for _, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
index += 1
return position_map
def sort_by_position_map(

View File

@ -12,7 +12,6 @@ from flask import Flask, current_app
from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
@ -20,12 +19,16 @@ from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from core.splitter.text_splitter import TextSplitter
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@ -283,11 +286,7 @@ class IndexingRunner:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
tokens += embedding_model_type_instance.get_num_tokens(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
tokens += embedding_model_instance.get_text_embedding_num_tokens(
texts=[self.filter_string(document.page_content)]
)
@ -428,7 +427,7 @@ class IndexingRunner:
chunk_size=segmentation["max_tokens"],
chunk_overlap=chunk_overlap,
fixed_separator=separator,
separators=["\n\n", "", ".", " ", ""],
separators=["\n\n", "", ". ", " ", ""],
embedding_model_instance=embedding_model_instance
)
else:
@ -436,7 +435,7 @@ class IndexingRunner:
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
separators=["\n\n", "", ".", " ", ""],
separators=["\n\n", "", ". ", " ", ""],
embedding_model_instance=embedding_model_instance
)
@ -655,10 +654,6 @@ class IndexingRunner:
tokens = 0
chunk_size = 10
embedding_model_type_instance = None
if embedding_model_instance:
embedding_model_type_instance = embedding_model_instance.model_type_instance
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
# create keyword index
create_keyword_thread = threading.Thread(target=self._process_keyword_index,
args=(current_app._get_current_object(),
@ -671,8 +666,7 @@ class IndexingRunner:
chunk_documents = documents[i:i + chunk_size]
futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
chunk_documents, dataset,
dataset_document, embedding_model_instance,
embedding_model_type_instance))
dataset_document, embedding_model_instance))
for future in futures:
tokens += future.result()
@ -713,7 +707,7 @@ class IndexingRunner:
db.session.commit()
def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
embedding_model_instance, embedding_model_type_instance):
embedding_model_instance):
with flask_app.app_context():
# check document is paused
self._check_document_paused_status(dataset_document.id)
@ -721,9 +715,7 @@ class IndexingRunner:
tokens = 0
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
tokens += sum(
embedding_model_type_instance.get_num_tokens(
embedding_model_instance.model,
embedding_model_instance.credentials,
embedding_model_instance.get_text_embedding_num_tokens(
[document.page_content]
)
for document in chunk_documents

View File

@ -1,3 +1,5 @@
from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file.message_file_parser import MessageFileParser
from core.model_manager import ModelInstance
@ -9,8 +11,6 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import model_provider_factory
from extensions.ext_database import db
from models.model import AppMode, Conversation, Message
@ -21,7 +21,7 @@ class TokenBufferMemory:
self.model_instance = model_instance
def get_history_prompt_messages(self, max_token_limit: int = 2000,
message_limit: int = 10) -> list[PromptMessage]:
message_limit: Optional[int] = None) -> list[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit
@ -30,10 +30,15 @@ class TokenBufferMemory:
app_record = self.conversation.app
# fetch limited messages, and return reversed
messages = db.session.query(Message).filter(
query = db.session.query(Message).filter(
Message.conversation_id == self.conversation.id,
Message.answer != ''
).order_by(Message.created_at.desc()).limit(message_limit).all()
).order_by(Message.created_at.desc())
if message_limit and message_limit > 0:
messages = query.limit(message_limit).all()
else:
messages = query.all()
messages = list(reversed(messages))
message_file_parser = MessageFileParser(
@ -78,12 +83,7 @@ class TokenBufferMemory:
return []
# prune the chat message if it exceeds the max token limit
provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
curr_message_tokens = model_type_instance.get_num_tokens(
self.model_instance.model,
self.model_instance.credentials,
curr_message_tokens = self.model_instance.get_llm_num_tokens(
prompt_messages
)
@ -91,9 +91,7 @@ class TokenBufferMemory:
pruned_memory = []
while curr_message_tokens > max_token_limit and prompt_messages:
pruned_memory.append(prompt_messages.pop(0))
curr_message_tokens = model_type_instance.get_num_tokens(
self.model_instance.model,
self.model_instance.credentials,
curr_message_tokens = self.model_instance.get_llm_num_tokens(
prompt_messages
)
@ -102,7 +100,7 @@ class TokenBufferMemory:
def get_history_prompt_text(self, human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int = 10) -> str:
message_limit: Optional[int] = None) -> str:
"""
Get history prompt text.
:param human_prefix: human prefix

View File

@ -1,7 +1,10 @@
import logging
import os
from collections.abc import Generator
from typing import IO, Optional, Union, cast
from core.entities.provider_configuration import ProviderModelBundle
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
@ -9,6 +12,7 @@ from core.model_runtime.entities.message_entities import PromptMessage, PromptMe
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
@ -16,6 +20,10 @@ from core.model_runtime.model_providers.__base.speech2text_model import Speech2T
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client
from models.provider import ProviderType
logger = logging.getLogger(__name__)
class ModelInstance:
@ -29,6 +37,12 @@ class ModelInstance:
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.model_type_instance = self.provider_model_bundle.model_type_instance
self.load_balancing_manager = self._get_load_balancing_manager(
configuration=provider_model_bundle.configuration,
model_type=provider_model_bundle.model_type_instance.model_type,
model=model,
credentials=self.credentials
)
def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
"""
@ -37,8 +51,10 @@ class ModelInstance:
:param model: model name
:return:
"""
credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=provider_model_bundle.model_type_instance.model_type,
configuration = provider_model_bundle.configuration
model_type = provider_model_bundle.model_type_instance.model_type
credentials = configuration.get_current_credentials(
model_type=model_type,
model=model
)
@ -47,6 +63,43 @@ class ModelInstance:
return credentials
def _get_load_balancing_manager(self, configuration: ProviderConfiguration,
model_type: ModelType,
model: str,
credentials: dict) -> Optional["LBModelManager"]:
"""
Get load balancing model credentials
:param configuration: provider configuration
:param model_type: model type
:param model: model name
:param credentials: model credentials
:return:
"""
if configuration.model_settings and configuration.using_provider_type == ProviderType.CUSTOM:
current_model_setting = None
# check if model is disabled by admin
for model_setting in configuration.model_settings:
if (model_setting.model_type == model_type
and model_setting.model == model):
current_model_setting = model_setting
break
# check if load balancing is enabled
if current_model_setting and current_model_setting.load_balancing_configs:
# use load balancing proxy to choose credentials
lb_model_manager = LBModelManager(
tenant_id=configuration.tenant_id,
provider=configuration.provider.provider,
model_type=model_type,
model=model,
load_balancing_configs=current_model_setting.load_balancing_configs,
managed_credentials=credentials if configuration.custom_configuration.provider else None
)
return lb_model_manager
return None
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
@ -67,7 +120,8 @@ class ModelInstance:
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self.model_type_instance.invoke(
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
@ -79,6 +133,27 @@ class ModelInstance:
callbacks=callbacks
)
def get_llm_num_tokens(self, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for llm
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools
)
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
@ -92,13 +167,32 @@ class ModelInstance:
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self.model_type_instance.invoke(
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
texts=texts,
user=user
)
def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
"""
Get number of tokens for text embedding
:param texts: texts to embed
:return:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
credentials=self.credentials,
texts=texts
)
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None) \
@ -117,7 +211,8 @@ class ModelInstance:
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return self.model_type_instance.invoke(
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
query=query,
@ -140,7 +235,8 @@ class ModelInstance:
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return self.model_type_instance.invoke(
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
text=text,
@ -160,7 +256,8 @@ class ModelInstance:
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return self.model_type_instance.invoke(
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
file=file,
@ -183,7 +280,8 @@ class ModelInstance:
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.invoke(
return self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
credentials=self.credentials,
content_text=content_text,
@ -193,7 +291,44 @@ class ModelInstance:
streaming=streaming
)
def get_tts_voices(self, language: str) -> list:
def _round_robin_invoke(self, function: callable, *args, **kwargs):
"""
Round-robin invoke
:param function: function to invoke
:param args: function args
:param kwargs: function kwargs
:return:
"""
if not self.load_balancing_manager:
return function(*args, **kwargs)
last_exception = None
while True:
lb_config = self.load_balancing_manager.fetch_next()
if not lb_config:
if not last_exception:
raise ProviderTokenNotInitError("Model credentials is not initialized.")
else:
raise last_exception
try:
if 'credentials' in kwargs:
del kwargs['credentials']
return function(*args, **kwargs, credentials=lb_config.credentials)
except InvokeRateLimitError as e:
# expire in 60 seconds
self.load_balancing_manager.cooldown(lb_config, expire=60)
last_exception = e
continue
except (InvokeAuthorizationError, InvokeConnectionError) as e:
# expire in 10 seconds
self.load_balancing_manager.cooldown(lb_config, expire=10)
last_exception = e
continue
except Exception as e:
raise e
def get_tts_voices(self, language: Optional[str] = None) -> list:
"""
Invoke large language tts model voices
@ -226,6 +361,7 @@ class ModelManager:
"""
if not provider:
return self.get_default_model_instance(tenant_id, model_type)
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=provider,
@ -255,3 +391,141 @@ class ModelManager:
model_type=model_type,
model=default_model_entity.model
)
class LBModelManager:
def __init__(self, tenant_id: str,
provider: str,
model_type: ModelType,
model: str,
load_balancing_configs: list[ModelLoadBalancingConfiguration],
managed_credentials: Optional[dict] = None) -> None:
"""
Load balancing model manager
:param load_balancing_configs: all load balancing configurations
:param managed_credentials: credentials if load balancing configuration name is __inherit__
"""
self._tenant_id = tenant_id
self._provider = provider
self._model_type = model_type
self._model = model
self._load_balancing_configs = load_balancing_configs
for load_balancing_config in self._load_balancing_configs:
if load_balancing_config.name == "__inherit__":
if not managed_credentials:
# remove __inherit__ if managed credentials is not provided
self._load_balancing_configs.remove(load_balancing_config)
else:
load_balancing_config.credentials = managed_credentials
def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]:
"""
Get next model load balancing config
Strategy: Round Robin
:return:
"""
cache_key = "model_lb_index:{}:{}:{}:{}".format(
self._tenant_id,
self._provider,
self._model_type.value,
self._model
)
cooldown_load_balancing_configs = []
max_index = len(self._load_balancing_configs)
while True:
current_index = redis_client.incr(cache_key)
if current_index >= 10000000:
current_index = 1
redis_client.set(cache_key, current_index)
redis_client.expire(cache_key, 3600)
if current_index > max_index:
current_index = current_index % max_index
real_index = current_index - 1
if real_index > max_index:
real_index = 0
config = self._load_balancing_configs[real_index]
if self.in_cooldown(config):
cooldown_load_balancing_configs.append(config)
if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs):
# all configs are in cooldown
return None
continue
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n"
f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
f"model_type: {self._model_type.value}\nmodel: {self._model}")
return config
return None
def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
"""
Cooldown model load balancing config
:param config: model load balancing config
:param expire: cooldown time
:return:
"""
cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
self._tenant_id,
self._provider,
self._model_type.value,
self._model,
config.id
)
redis_client.setex(cooldown_cache_key, expire, 'true')
def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
"""
Check if model load balancing config is in cooldown
:param config: model load balancing config
:return:
"""
cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
self._tenant_id,
self._provider,
self._model_type.value,
self._model,
config.id
)
return redis_client.exists(cooldown_cache_key)
@classmethod
def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
provider: str,
model_type: ModelType,
model: str,
config_id: str) -> tuple[bool, int]:
"""
Get model load balancing config is in cooldown and ttl
:param tenant_id: workspace id
:param provider: provider name
:param model_type: model type
:param model: model name
:param config_id: model load balancing config id
:return:
"""
cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
tenant_id,
provider,
model_type.value,
model,
config_id
)
ttl = redis_client.ttl(cooldown_cache_key)
if ttl == -2:
return False, 0
return True, ttl

View File

@ -20,7 +20,7 @@ This module provides the interface for invoking and authenticating various model
![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png)
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./schema.md).
Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
- Selectable model list display

View File

@ -1,4 +1,3 @@
from abc import ABC
from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -14,7 +13,7 @@ _TEXT_COLOR_MAPPING = {
}
class Callback(ABC):
class Callback:
"""
Base class for callbacks.
Only for LLM.

View File

@ -336,7 +336,7 @@ Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement
- Invoke Invocation
```python
def _invoke(elf, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
"""
Invoke large language model

View File

@ -376,7 +376,7 @@ class XinferenceProvider(Provider):
- Invoke 调用
```python
def _invoke(elf, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
"""
Invoke large language model

View File

@ -3,8 +3,7 @@ import os
from abc import ABC, abstractmethod
from typing import Optional
import yaml
from core.helper.position_helper import get_position_map, sort_by_position_map
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
@ -18,7 +17,7 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.utils.position_helper import get_position_map, sort_by_position_map
from core.tools.utils.yaml_utils import load_yaml_file
class AIModel(ABC):
@ -154,8 +153,7 @@ class AIModel(ABC):
# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
# read yaml data from yaml file
with open(model_schema_yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
yaml_data = load_yaml_file(model_schema_yaml_path, ignore_error=True)
new_parameter_rules = []
for parameter_rule in yaml_data.get('parameter_rules', []):

View File

@ -1,12 +1,11 @@
import os
from abc import ABC, abstractmethod
import yaml
from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source
from core.tools.utils.yaml_utils import load_yaml_file
class ModelProvider(ABC):
@ -44,10 +43,7 @@ class ModelProvider(ABC):
# read provider schema from yaml file
yaml_path = os.path.join(current_path, f'{provider_name}.yaml')
yaml_data = {}
if os.path.exists(yaml_path):
with open(yaml_path, encoding='utf-8') as f:
yaml_data = yaml.safe_load(f)
yaml_data = load_yaml_file(yaml_path, ignore_error=True)
try:
# yaml_data to entity

View File

@ -2,7 +2,9 @@
- anthropic
- azure_openai
- google
- vertex_ai
- nvidia
- nvidia_nim
- cohere
- bedrock
- togetherai
@ -29,3 +31,5 @@
- volcengine_maas
- openai_api_compatible
- deepseek
- hunyuan
- siliconflow

View File

@ -49,7 +49,7 @@ LLM_BASE_MODELS = [
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 4096,
ModelPropertyKey.CONTEXT_SIZE: 16385,
},
parameter_rules=[
ParameterRule(
@ -68,11 +68,25 @@ LLM_BASE_MODELS = [
name='frequency_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096)
_get_max_tokens(default=512, min_val=1, max_val=4096),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
),
required=False,
options=['text', 'json_object']
),
],
pricing=PriceConfig(
input=0.001,
output=0.002,
input=0.0005,
output=0.0015,
unit=0.001,
currency='USD',
)
@ -482,6 +496,158 @@ LLM_BASE_MODELS = [
)
)
),
AzureBaseModel(
base_model_name='gpt-4o',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label',
),
model_type=ModelType.LLM,
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.VISION,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name='temperature',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
ParameterRule(
name='seed',
label=I18nObject(
zh_Hans='种子',
en_US='Seed'
),
type='int',
help=I18nObject(
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
),
required=False,
precision=2,
min=0,
max=1,
),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
),
required=False,
options=['text', 'json_object']
),
],
pricing=PriceConfig(
input=5.00,
output=15.00,
unit=0.000001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='gpt-4o-2024-05-13',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label',
),
model_type=ModelType.LLM,
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.VISION,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name='temperature',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
ParameterRule(
name='seed',
label=I18nObject(
zh_Hans='种子',
en_US='Seed'
),
type='int',
help=I18nObject(
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
),
required=False,
precision=2,
min=0,
max=1,
),
ParameterRule(
name='response_format',
label=I18nObject(
zh_Hans='回复格式',
en_US='response_format'
),
type='string',
help=I18nObject(
zh_Hans='指定模型必须输出的格式',
en_US='specifying the format that the model must output'
),
required=False,
options=['text', 'json_object']
),
],
pricing=PriceConfig(
input=5.00,
output=15.00,
unit=0.000001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='gpt-4-turbo',
entity=AIModelEntity(
@ -551,8 +717,8 @@ LLM_BASE_MODELS = [
),
],
pricing=PriceConfig(
input=0.001,
output=0.003,
input=0.01,
output=0.03,
unit=0.001,
currency='USD',
)
@ -627,8 +793,8 @@ LLM_BASE_MODELS = [
),
],
pricing=PriceConfig(
input=0.001,
output=0.003,
input=0.01,
output=0.03,
unit=0.001,
currency='USD',
)

View File

@ -53,12 +53,24 @@ model_credential_schema:
type: select
required: true
options:
- label:
en_US: 2024-05-01-preview
value: 2024-05-01-preview
- label:
en_US: 2024-04-01-preview
value: 2024-04-01-preview
- label:
en_US: 2024-03-01-preview
value: 2024-03-01-preview
- label:
en_US: 2024-02-15-preview
value: 2024-02-15-preview
- label:
en_US: 2023-12-01-preview
value: 2023-12-01-preview
- label:
en_US: '2024-02-01'
value: '2024-02-01'
placeholder:
zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here
@ -99,6 +111,18 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o
value: gpt-4o
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-05-13
value: gpt-4o-2024-05-13
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo
value: gpt-4-turbo

View File

@ -6,7 +6,7 @@ features:
- agent-thought
model_properties:
mode: chat
context_size: 4000
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -6,7 +6,7 @@ features:
- agent-thought
model_properties:
mode: chat
context_size: 192000
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature

View File

@ -0,0 +1,45 @@
model: baichuan3-turbo-128k
label:
en_US: Baichuan3-Turbo-128k
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8000
min: 1
max: 128000
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
default: 1
min: 1
max: 2
- name: with_search_enhance
label:
zh_Hans: 搜索增强
en_US: Search Enhance
type: boolean
help:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false

View File

@ -0,0 +1,45 @@
model: baichuan3-turbo
label:
en_US: Baichuan3-Turbo
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8000
min: 1
max: 32000
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
default: 1
min: 1
max: 2
- name: with_search_enhance
label:
zh_Hans: 搜索增强
en_US: Search Enhance
type: boolean
help:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false

View File

@ -0,0 +1,45 @@
model: baichuan4
label:
en_US: Baichuan4
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8000
min: 1
max: 32000
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
default: 1
min: 1
max: 2
- name: with_search_enhance
label:
zh_Hans: 搜索增强
en_US: Search Enhance
type: boolean
help:
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
en_US: Allow the model to perform external search to enhance the generation results.
required: false

View File

@ -51,26 +51,29 @@ class BaichuanModel:
'baichuan2-turbo': 'Baichuan2-Turbo',
'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
'baichuan2-53b': 'Baichuan2-53B',
'baichuan3-turbo': 'Baichuan3-Turbo',
'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
'baichuan4': 'Baichuan4',
}[model]
def _handle_chat_generate_response(self, response) -> BaichuanMessage:
resp = response.json()
choices = resp.get('choices', [])
message = BaichuanMessage(content='', role='assistant')
for choice in choices:
message.content += choice['message']['content']
message.role = choice['message']['role']
if choice['finish_reason']:
message.stop_reason = choice['finish_reason']
resp = response.json()
choices = resp.get('choices', [])
message = BaichuanMessage(content='', role='assistant')
for choice in choices:
message.content += choice['message']['content']
message.role = choice['message']['role']
if choice['finish_reason']:
message.stop_reason = choice['finish_reason']
if 'usage' in resp:
message.usage = {
'prompt_tokens': resp['usage']['prompt_tokens'],
'completion_tokens': resp['usage']['completion_tokens'],
'total_tokens': resp['usage']['total_tokens'],
}
return message
if 'usage' in resp:
message.usage = {
'prompt_tokens': resp['usage']['prompt_tokens'],
'completion_tokens': resp['usage']['completion_tokens'],
'total_tokens': resp['usage']['total_tokens'],
}
return message
def _handle_chat_stream_generate_response(self, response) -> Generator:
for line in response.iter_lines():
@ -110,7 +113,8 @@ class BaichuanModel:
def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
parameters: dict[str, Any]) \
-> dict[str, Any]:
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
prompt_messages = []
for message in messages:
if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
@ -143,7 +147,8 @@ class BaichuanModel:
raise BadRequestError(f"Unknown model: {model}")
def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
# there is no secret key for turbo api
return {
'Content-Type': 'application/json',
@ -160,7 +165,8 @@ class BaichuanModel:
parameters: dict[str, Any], timeout: int) \
-> Union[Generator, BaichuanMessage]:
if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
else:
raise BadRequestError(f"Unknown model: {model}")

View File

@ -7,6 +7,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
@ -32,20 +33,21 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
class BaichuanLarguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
return self._num_tokens_from_messages(prompt_messages)
def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int:
def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
"""Calculate num tokens for baichuan model"""
def tokens(text: str):
return BaichuanTokenizer._get_num_tokens(text)
@ -85,9 +87,20 @@ class BaichuanLarguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, ToolPromptMessage):
# copy from core/model_runtime/model_providers/anthropic/llm/llm.py
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": message.tool_call_id,
"content": message.content
}]
}
else:
raise ValueError(f"Unknown message type {type(message)}")
return message_dict
def validate_credentials(self, model: str, credentials: dict) -> None:
@ -106,13 +119,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
except Exception as e:
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
if tools is not None and len(tools) > 0:
raise InvokeBadRequestError("Baichuan model doesn't support tools")
instance = BaichuanModel(
api_key=credentials['api_key'],
secret_key=credentials.get('secret_key', '')
@ -129,11 +142,12 @@ class BaichuanLarguageModel(LargeLanguageModel):
]
# invoke model
response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60)
response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
timeout=60)
if stream:
return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
def _handle_chat_generate_response(self, model: str,
@ -141,7 +155,9 @@ class BaichuanLarguageModel(LargeLanguageModel):
credentials: dict,
response: BaichuanMessage) -> LLMResult:
# convert baichuan message to llm result
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens'])
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=response.usage['prompt_tokens'],
completion_tokens=response.usage['completion_tokens'])
return LLMResult(
model=model,
prompt_messages=prompt_messages,
@ -158,7 +174,9 @@ class BaichuanLarguageModel(LargeLanguageModel):
response: Generator[BaichuanMessage, None, None]) -> Generator:
for message in response:
if message.usage:
usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens'])
usage = self._calc_response_usage(model=model, credentials=credentials,
prompt_tokens=message.usage['prompt_tokens'],
completion_tokens=message.usage['completion_tokens'])
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,

View File

@ -12,6 +12,7 @@
- meta.llama3-70b-instruct-v1:0
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
- mistral.mistral-small-2402-v1:0
- mistral.mistral-large-2402-v1:0
- mistral.mixtral-8x7b-instruct-v0:1
- mistral.mistral-7b-instruct-v0:2

View File

@ -51,7 +51,7 @@ parameter_rules:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.003'
output: '0.015'
input: '0.00025'
output: '0.00125'
unit: '0.001'
currency: USD

View File

@ -50,7 +50,7 @@ parameter_rules:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
input: '0.00025'
output: '0.00125'
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD

View File

@ -358,26 +358,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
return message_dict
def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param messages: prompt messages or message string
:param prompt_messages: prompt messages or message string
:param tools: tools for tool calling
:return:md = genai.GenerativeModel(model)
"""
prefix = model.split('.')[0]
model_name = model.split('.')[1]
if isinstance(messages, str):
prompt = messages
if isinstance(prompt_messages, str):
prompt = prompt_messages
else:
prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)
return self._get_num_tokens_by_gpt2(prompt)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""

View File

@ -0,0 +1,27 @@
model: mistral.mistral-small-2402-v1:0
label:
en_US: Mistral Small
model_type: llm
model_properties:
mode: completion
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
required: false
default: 0.7
- name: top_p
use_template: top_p
required: false
default: 1
- name: max_tokens
use_template: max_tokens
required: true
default: 512
min: 1
max: 4096
pricing:
input: '0.001'
output: '0.03'
unit: '0.001'
currency: USD

View File

@ -1,3 +1,4 @@
- amazon.titan-embed-text-v1
- amazon.titan-embed-text-v2:0
- cohere.embed-english-v3
- cohere.embed-multilingual-v3

View File

@ -4,5 +4,5 @@ model_properties:
context_size: 8192
pricing:
input: '0.0001'
unit: '0.001'
unit: '0.0001'
currency: USD

View File

@ -0,0 +1,8 @@
model: amazon.titan-embed-text-v2:0
model_type: text-embedding
model_properties:
context_size: 8192
pricing:
input: '0.00002'
unit: '0.00001'
currency: USD

View File

@ -59,15 +59,15 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
model_prefix = model.split('.')[0]
if model_prefix == "amazon" :
for text in texts:
body = {
for text in texts:
body = {
"inputText": text,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend([response_body.get('embedding')])
token_usage += response_body.get('inputTextTokenCount')
logger.warning(f'Total Tokens: {token_usage}')
result = TextEmbeddingResult(
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend([response_body.get('embedding')])
token_usage += response_body.get('inputTextTokenCount')
logger.warning(f'Total Tokens: {token_usage}')
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
@ -75,20 +75,20 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
credentials=credentials,
tokens=token_usage
)
)
return result
)
return result
if model_prefix == "cohere" :
input_type = 'search_document' if len(texts) > 1 else 'search_query'
for text in texts:
body = {
input_type = 'search_document' if len(texts) > 1 else 'search_query'
for text in texts:
body = {
"texts": [text],
"input_type": input_type,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend(response_body.get('embeddings'))
token_usage += len(text)
result = TextEmbeddingResult(
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend(response_body.get('embeddings'))
token_usage += len(text)
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
@ -96,9 +96,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
credentials=credentials,
tokens=token_usage
)
)
return result
)
return result
#others
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@ -183,7 +183,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
)
return usage
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
"""
Map client error to invoke error
@ -212,9 +212,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
content_type = 'application/json'
try:
response = bedrock_runtime.invoke_model(
body=json.dumps(body),
modelId=model,
accept=accept,
body=json.dumps(body),
modelId=model,
accept=accept,
contentType=content_type
)
response_body = json.loads(response.get('body').read().decode('utf-8'))

View File

@ -0,0 +1,39 @@
model: gemini-1.5-flash-latest
label:
en_US: Gemini 1.5 Flash
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens_to_sample
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@ -1,18 +1,22 @@
import base64
import json
import logging
import mimetypes
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast
import google.ai.generativelanguage as glm
import google.api_core.exceptions as exceptions
import google.generativeai as genai
import google.generativeai.client as client
import requests
from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
from google.generativeai.types.content_types import to_part
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
@ -204,6 +208,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
stream=stream,
safety_settings=safety_settings,
tools=self._convert_tools_to_glm_tool(tools) if tools else None,
request_options={"timeout": 600}
)
if stream:
@ -360,11 +365,22 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
glm_content['parts'].append(to_part(c.data))
else:
metadata, data = c.data.split(',', 1)
mime_type = metadata.split(';', 1)[0].split(':')[1]
blob = {"inline_data":{"mime_type":mime_type,"data":data}}
elif c.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(',', 1)
mime_type = metadata.split(';', 1)[0].split(':')[1]
else:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}}
glm_content['parts'].append(blob)
return glm_content
elif isinstance(message, AssistantPromptMessage):
glm_content = {
@ -443,4 +459,4 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
exceptions.RequestRangeNotSatisfiable,
exceptions.Cancelled,
]
}
}

Some files were not shown because too many files have changed in this diff Show More