Compare commits

...

97 Commits
0.5.7 ... 0.5.9

Author SHA1 Message Date
ce5b19d011 bump version to 0.5.9 (#2794) 2024-03-12 14:01:24 +08:00
f82a64d149 feat: add DingTalk(钉钉) tool for sending message to chat group bot via webhook (#2693) 2024-03-12 13:45:59 +08:00
f49b1afd6c feat:support azure tts (#2751) 2024-03-12 12:06:35 +08:00
796c5626a7 fix delete dataset when dataset has no document (#2789)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-11 23:57:38 +08:00
e54c9cd401 Feat/open ai compatible functioncall (#2783)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-11 19:48:21 +08:00
f8951d7f57 fix: api tool provider not found (#2782) 2024-03-11 18:21:41 +08:00
6454e1d644 chunk-overlap None check (#2781)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-11 15:36:56 +08:00
e184c8cb42 Update README.md (#2780) 2024-03-11 14:53:40 +08:00
fdd211e399 debug/chat: increase notify error duration to 3000 (#2778) 2024-03-11 14:16:31 +08:00
7001e21e7d overview: fix filter today calc start & end (#2777) 2024-03-11 14:11:51 +08:00
82d0732c12 fix: aippt default styles (#2779) 2024-03-11 14:04:09 +08:00
53cd125780 fix: deep copy of model-tool label (#2775) 2024-03-11 10:27:00 +08:00
3c91f9b5ab fix: dataset segements api (#2766) 2024-03-11 09:26:15 +08:00
f073dca22a feat: optimize db connection when llm invoking (#2774) 2024-03-10 15:48:31 +08:00
8b1e35d7dc doc: add suggested questions back (#2771) 2024-03-10 15:40:17 +08:00
b75d8ca621 fix: auto closing when close local image uploading (#2767) 2024-03-10 13:11:41 +08:00
9beefd7d5a fix: auto prompt (#2768) 2024-03-09 18:36:58 +08:00
88145efa97 fix: app name can be empty in settings modal (#2761) 2024-03-09 09:13:12 +08:00
bdc13f9238 SMTP authentication is optional (#2765)
Co-authored-by: Laurent Magnien <laurent.magnien@adsn.fr>
2024-03-09 09:11:03 +08:00
ce58f0607b Feat/tool secret parameter (#2760) 2024-03-08 20:31:13 +08:00
bbc0d330a9 chore: rename lastStep to previousStep (#2759) 2024-03-08 19:27:02 +08:00
60e7e17c86 feat: Add new Azure OpenAI Embedding models (#2758) 2024-03-08 19:04:20 +08:00
237bb8514e replace message content type list to string when file_objs is empty .. (#2745) 2024-03-08 18:46:31 +08:00
bd26c933d2 fix: valid password on reset-password page (#2753) 2024-03-08 18:44:49 +08:00
b6b58da2d2 enhance: custom tool timeout (#2754) 2024-03-08 15:26:08 +08:00
40c646cf7a Feat/model as tool (#2744) 2024-03-08 15:22:55 +08:00
3231a8c51c fix: image tokenizer (#2752) 2024-03-08 14:50:51 +08:00
4170d6a491 use SVG icons for built-in tools (#2748) 2024-03-08 10:21:26 +08:00
0b50c525cf feat: support error correction and border size in qrcode tool (#2731) 2024-03-07 20:54:14 +08:00
8ba38e8e74 fix overlap and splitter optimization (#2742)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-07 18:25:49 +08:00
b163545771 Use python-docx to extract docx files (#2654) 2024-03-07 18:24:55 +08:00
c0b82f8e58 UPDATE: Twilio tool crdential verification (#2741) 2024-03-07 18:08:52 +08:00
b75ff5fa03 fix:missing import (#2739) 2024-03-07 17:31:30 +08:00
9440d7fe88 fix: the behavior of save action in opening config panel (#2736) 2024-03-07 16:48:44 +08:00
24809fce07 fix: missing en_name of aippt (#2737) 2024-03-07 16:37:12 +08:00
9819ad347f feat:support azure whisper model and fix:rename text-embedidng-ada-002.yaml to text-embedding-ada-002.yaml (#2732) 2024-03-07 16:36:58 +08:00
8fe83750b7 Fix/jina tokenizer cache (#2735) 2024-03-07 16:32:37 +08:00
1809f05904 Feat/add groq (#2733) 2024-03-07 16:00:40 +08:00
0ac250a035 fix: check webhook key of Wecom tool in valid UUID form and fix typo (#2719) 2024-03-07 15:51:06 +08:00
405a00bb2c fix:delete the slash at the end of xinference provider server_url (#2730) 2024-03-07 15:37:05 +08:00
3a3ca8e6a9 fix: max tokens can only up to 2048 (#2734) 2024-03-07 15:35:56 +08:00
27e678480e Feat: AIPPT & DynamicToolParamter (#2725) 2024-03-07 15:04:42 +08:00
7052565380 fix typo: responsing -> responding (#2718)
Co-authored-by: OSS-MAOLONGDONG\kaihong <maolongdong@kaihong.com>
2024-03-07 10:20:35 +08:00
31070ffbca fix qa index processor tenant id is None error (#2713)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-06 16:46:08 +08:00
7f3dec7bee fix error msg format issue (#2715)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-06 16:45:40 +08:00
b1e0db4944 fix: chatbot service api auto generate name default value error (#2709) 2024-03-06 13:19:27 +08:00
c439952a41 fix(web): chat input auto resize by window (#2696) 2024-03-06 12:49:22 +08:00
2f28afebb6 FEAT: Add twilio tool for sending text and whatsapp messages (#2700)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-06 11:35:08 +08:00
fa7ba30ba3 Fix rebuild index&csv parsing (#2705)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-06 11:33:32 +08:00
1cf5f510ed feat: add qrcode tool for QR code generation (#2699) 2024-03-06 11:26:16 +08:00
526c874caa fix mistralai icon (#2707) 2024-03-06 11:08:22 +08:00
f88f744097 make volume folders for milvus docker containers ignored by git (#2694) 2024-03-05 17:26:21 +08:00
95733796f0 fix: replace os.path.join with yarl (#2690) 2024-03-05 17:25:20 +08:00
552f319b9d feat: support HTTP response compression in api server (#2680) 2024-03-05 14:45:22 +08:00
38e5952417 Fix/agent react output parser (#2689) 2024-03-05 14:02:07 +08:00
7f891939f1 FEAT: add tavily tool for searching... A search engine for LLM (#2681) 2024-03-05 10:23:44 +08:00
69a5ce1e31 Fix tts play logic (#2683)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-05 09:22:36 +08:00
534802b761 bump version to 0.5.8 (#2685) 2024-03-05 01:37:53 +08:00
5c258e212c feat: add Anthropic claude-3 models support (#2684) 2024-03-05 01:37:42 +08:00
6a6133c102 Fix voice selection (#2664)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-03-04 17:50:06 +08:00
3c1825187a fix: auto generate prompt result not show (#2678) 2024-03-04 17:36:11 +08:00
8523b34be7 add jina-reranker-v1-base-en (#2676) 2024-03-04 17:31:01 +08:00
65cfd4360a fix: typo in wecom tool (#2674) 2024-03-04 17:25:42 +08:00
bbf5f42c87 fix: CE edition limits upload file nums (#2677) 2024-03-04 17:25:31 +08:00
3631e53ff0 Feat/add annotation migrate (#2675)
Co-authored-by: jyong <jyong@dify.ai>
2024-03-04 17:22:06 +08:00
f322d9bddb Fix vdb merge error (#2650) 2024-03-04 16:35:50 +08:00
05ce7b9d5e fix: deep copy customColletion (#2673) 2024-03-04 15:20:20 +08:00
72ddedfc5c fix: setup default filters while add credentials (#2669) 2024-03-04 14:17:00 +08:00
36686d7425 fix: test custom tool already exists without decrypting credentials (#2668) 2024-03-04 14:16:47 +08:00
34387ec0f1 fix typo recale to recalc (#2670) 2024-03-04 14:15:53 +08:00
83a6b0c626 Doc/update license (#2666) 2024-03-04 14:10:39 +08:00
76da66fb7e fix: fix import from explore apps err when OpenAI not inited (#2671) 2024-03-04 14:06:54 +08:00
607f9eda35 Fix/app runner typo (#2661) 2024-03-04 13:32:17 +08:00
f25cec265d feat: add Wecom(企业微信) tool for sending message to chat group bot via webhook (#2638) 2024-03-04 10:27:20 +08:00
8e66b96221 Feat: Add documents limitation (#2662) 2024-03-03 12:45:06 +08:00
b5c1bb346c Add PubMed to tools (#2652) 2024-03-03 12:44:13 +08:00
e94b323e6c fix: use English as the default i18n language (#2663) 2024-03-03 12:35:28 +08:00
bc65ee10c0 bugfix: model str maybe empty (#2660) 2024-03-03 11:43:38 +08:00
2001483659 fix: default to allcategories when search params is not from recommended (#2653) 2024-03-02 17:11:25 +08:00
444aba55dd Feat/jpn support (#2651) 2024-03-02 13:47:51 +08:00
3f640b1037 fix: click tool item in app debug page would show detail (#2644) 2024-03-01 18:47:12 +08:00
b07084711c fix: missing description (#2643) 2024-03-01 18:19:04 +08:00
fa8ab2134f feat: displaying the tool description when clicking on a custom tool (#2642) 2024-03-01 17:58:38 +08:00
1a677da792 fix: custom tool max tool (#2641) 2024-03-01 16:43:47 +08:00
b6d61a818e fix: Replace path.join with urljoin. (#2631) 2024-03-01 13:07:15 +08:00
8495ffaa45 fix: typo in gaode tool (#2636) 2024-03-01 10:12:48 +08:00
dbd1d79770 FEAT: Add arxiv tool for searching scientific papers and articles fro… (#2632) 2024-02-29 19:46:10 +08:00
1910178199 fix: default mail type invalid in .env.example (#2628) 2024-02-29 17:29:48 +08:00
839a6a2c8a add logs for vdb-migrate command (#2626) 2024-02-29 16:24:51 +08:00
a769edbc89 Fix/custom tool any of (#2625) 2024-02-29 14:39:05 +08:00
57ffecd0e5 fix: remove unnecessary credentials of custom tool (#2621) 2024-02-29 12:58:12 +08:00
801d135390 generalize the generation of new collection name by dataset id (#2620) 2024-02-29 12:47:10 +08:00
0428f44113 chore: bump superlinter action from v5 to v6 (#2325) 2024-02-29 12:45:06 +08:00
7beff3fd5a fix: model parameter load presets config (#2622) 2024-02-29 12:43:46 +08:00
88a095e40e fix: wrong default model parameters when creating app (#2623) 2024-02-29 12:43:07 +08:00
dd961985f0 refactor: remove unused codes, move core/agent module into dataset retrieval feature (#2614) 2024-02-28 23:32:47 +08:00
d44b05a9e5 feat: support auth type like basic bearer and custom (#2613) 2024-02-28 23:19:08 +08:00
318 changed files with 8564 additions and 3088 deletions

View File

@@ -41,6 +41,8 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup NodeJS
uses: actions/setup-node@v4
@@ -60,11 +62,10 @@ jobs:
yarn run lint
- name: Super-linter
uses: super-linter/super-linter/slim@v5
uses: super-linter/super-linter/slim@v6
env:
BASH_SEVERITY: warning
DEFAULT_BRANCH: main
ERROR_ON_MISSING_EXEC_BIT: true
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true
IGNORE_GITIGNORED_FILES: true

.gitignore (vendored): 3 changes
View File

@@ -145,6 +145,9 @@ docker/volumes/db/data/*
docker/volumes/redis/data/*
docker/volumes/weaviate/*
docker/volumes/qdrant/*
docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*
sdks/python-client/build
sdks/python-client/dist

LICENSE: 22 changes
View File

@@ -1,24 +1,26 @@
# Dify Open Source License
# Open Source License
The Dify project is licensed under the Apache License 2.0, with the following additional conditions:
Dify is licensed under the Apache License 2.0, with the following additional conditions:
1. Dify is permitted to be used for commercialization, such as using Dify as a "backend-as-a-service" for your other applications, or delivering it to enterprises as an application development platform. However, when the following conditions are met, you must contact the producer to obtain a commercial license:
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify.AI source code to operate a multi-tenant SaaS service that is similar to the Dify.AI service edition.
b. LOGO and copyright information: In the process of using Dify, you may not remove or modify the LOGO or copyright information in the Dify console.
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.
Please contact business@dify.ai by email to inquire about licensing matters.
2. As a contributor, you should agree that your contributed code:
2. As a contributor, you should agree that:
a. The producer can adjust the open-source agreement to be more strict or relaxed.
b. Can be used for commercial purposes, such as Dify's cloud business.
a. The producer can adjust the open-source agreement to be more strict or relaxed as deemed necessary.
b. Your contributed code may be used for commercial purposes, including but not limited to its cloud business operations.
Apart from this, all other rights and restrictions follow the Apache License 2.0. If you need more detailed information, you can refer to the full version of Apache License 2.0.
Apart from the specific conditions mentioned above, all other rights and restrictions follow the Apache License 2.0. Detailed information about the Apache License 2.0 can be found at http://www.apache.org/licenses/LICENSE-2.0.
The interactive design of this product is protected by appearance patent.
© 2023 LangGenius, Inc.
© 2024 LangGenius, Inc.
----------

View File

@@ -21,17 +21,6 @@
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
</p>
<p align="center">
<a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
</a>
<ul align="center" style="text-decoration: none; list-style: none;">
<li> US EST: 09:00 (9:00 AM)</li>
<li> CET: 15:00 (3:00 PM)</li>
<li> CST: 22:00 (10:00 PM)</li>
</ul>
</p>
<p align="center">
<a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs

View File

@@ -82,7 +82,7 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
MULTIMODAL_SEND_IMAGE_FORMAT=base64
# Mail configuration, support: resend, smtp
MAIL_TYPE=resend
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
@@ -131,4 +131,4 @@ UNSTRUCTURED_API_URL=
SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=
BATCH_UPLOAD_LIMIT=10
BATCH_UPLOAD_LIMIT=10

View File

@@ -5,7 +5,7 @@
1. Start the docker-compose stack
The backend require some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
```bash
cd ../docker
docker-compose -f docker-compose.middleware.yaml -p dify up -d
@@ -15,7 +15,7 @@
3. Generate a `SECRET_KEY` in the `.env` file.
```bash
openssl rand -base64 42
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
```
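This step now writes the generated key straight into `.env` with `sed` instead of only printing it. For reference, the same kind of value (42 random bytes, base64-encoded, as `openssl rand -base64 42` produces) can also be generated from the Python standard library; this is only an illustration, not part of the repository's setup instructions:

```python
# Illustrative equivalent of `openssl rand -base64 42`:
# 42 random bytes, base64-encoded, usable as a SECRET_KEY value.
import base64
import os

print("SECRET_KEY=" + base64.b64encode(os.urandom(42)).decode("ascii"))
```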
3.5 If you use annaconda, create a new environment and activate it
```bash
@@ -46,7 +46,7 @@
```
pip install -r requirements.txt --upgrade --force-reinstall
```
6. Start backend:
```bash
flask run --host 0.0.0.0 --port=5001 --debug

View File

@@ -26,6 +26,7 @@ from config import CloudEditionConfig, Config
from extensions import (
ext_celery,
ext_code_based_extension,
ext_compress,
ext_database,
ext_hosting_provider,
ext_login,
@@ -96,6 +97,7 @@ def create_app(test_config=None) -> Flask:
def initialize_extensions(app):
# Since the application instance is now created, pass it to each Flask
# extension instance to bind it to the Flask application instance (app)
ext_compress.init_app(app)
ext_code_based_extension.init()
ext_database.init_app(app)
ext_migrate.init(app, db)

View File

@@ -15,7 +15,7 @@ from libs.rsa import generate_key_pair
from models.account import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account
from models.model import Account, App, AppAnnotationSetting, MessageAnnotation
from models.provider import Provider, ProviderModel
@@ -125,12 +125,121 @@ def reset_encrypt_key_pair():
@click.command('vdb-migrate', help='migrate vector db.')
def vdb_migrate():
@click.option('--scope', default='all', prompt=False, help='The scope of vector database to migrate, Default is All.')
def vdb_migrate(scope: str):
if scope in ['knowledge', 'all']:
migrate_knowledge_vector_database()
if scope in ['annotation', 'all']:
migrate_annotation_vector_database()
def migrate_annotation_vector_database():
"""
Migrate annotation datas to target vector database .
"""
click.echo(click.style('Start migrate annotation data.', fg='green'))
create_count = 0
skipped_count = 0
total_count = 0
page = 1
while True:
try:
# get apps info
apps = db.session.query(App).filter(
App.status == 'normal'
).order_by(App.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for app in apps:
total_count = total_count + 1
click.echo(f'Processing the {total_count} app {app.id}. '
+ f'{create_count} created, {skipped_count} skipped.')
try:
click.echo('Create app annotation index: {}'.format(app.id))
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
AppAnnotationSetting.app_id == app.id
).first()
if not app_annotation_setting:
skipped_count = skipped_count + 1
click.echo('App annotation setting is disabled: {}'.format(app.id))
continue
# get dataset_collection_binding info
dataset_collection_binding = db.session.query(DatasetCollectionBinding).filter(
DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id
).first()
if not dataset_collection_binding:
click.echo('App annotation collection binding is not exist: {}'.format(app.id))
continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
indexing_technique='high_quality',
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id
)
documents = []
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
metadata={
"annotation_id": annotation.id,
"app_id": app.id,
"doc_id": annotation.id
}
)
documents.append(document)
vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
click.echo(f"Start to migrate annotation, app_id: {app.id}.")
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index for app: {app.id}.',
fg='green'))
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index for app {app.id}.',
fg='red'))
raise e
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} annotations for app {app.id}.',
fg='green'))
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for app {app.id}.', fg='green'))
except Exception as e:
click.echo(click.style(f'Failed to created vector index for app {app.id}.', fg='red'))
raise e
click.echo(f'Successfully migrated app annotation {app.id}.')
create_count += 1
except Exception as e:
click.echo(
click.style('Create app annotation index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(
click.style(f'Congratulations! Create {create_count} app annotation indexes, and skipped {skipped_count} apps.',
fg='green'))
def migrate_knowledge_vector_database():
"""
Migrate vector database datas to target vector database .
"""
click.echo(click.style('Start migrate vector db.', fg='green'))
create_count = 0
skipped_count = 0
total_count = 0
config = current_app.config
vector_type = config.get('VECTOR_STORE')
page = 1
@@ -143,14 +252,19 @@ def vdb_migrate():
page += 1
for dataset in datasets:
total_count = total_count + 1
click.echo(f'Processing the {total_count} dataset {dataset.id}. '
+ f'{create_count} created, ${skipped_count} skipped.')
try:
click.echo('Create dataset vdb index: {}'.format(dataset.id))
if dataset.index_struct_dict:
if dataset.index_struct_dict['type'] == vector_type:
skipped_count = skipped_count + 1
continue
collection_name = ''
if vector_type == "weaviate":
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'weaviate',
"vector_store": {"class_prefix": collection_name}
@@ -167,7 +281,7 @@ def vdb_migrate():
raise ValueError('Dataset Collection Bindings is not exist!')
else:
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'qdrant',
"vector_store": {"class_prefix": collection_name}
@@ -176,7 +290,7 @@ def vdb_migrate():
elif vector_type == "milvus":
dataset_id = dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
index_struct_dict = {
"type": 'milvus',
"vector_store": {"class_prefix": collection_name}
@@ -186,11 +300,17 @@ def vdb_migrate():
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
vector = Vector(dataset)
click.echo(f"vdb_migrate {dataset.id}")
click.echo(f"Start to migrate dataset {dataset.id}.")
try:
vector.delete()
click.echo(
click.style(f'Successfully delete vector index {collection_name} for dataset {dataset.id}.',
fg='green'))
except Exception as e:
click.echo(
click.style(f'Failed to delete vector index {collection_name} for dataset {dataset.id}.',
fg='red'))
raise e
dataset_documents = db.session.query(DatasetDocument).filter(
@@ -201,6 +321,7 @@ def vdb_migrate():
).all()
documents = []
segments_count = 0
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
@@ -220,15 +341,22 @@ def vdb_migrate():
)
documents.append(document)
segments_count = segments_count + 1
if documents:
try:
click.echo(click.style(
f'Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.',
fg='green'))
vector.create(documents)
click.echo(
click.style(f'Successfully created vector index for dataset {dataset.id}.', fg='green'))
except Exception as e:
click.echo(click.style(f'Failed to created vector index for dataset {dataset.id}.', fg='red'))
raise e
click.echo(f"Dataset {dataset.id} create successfully.")
db.session.add(dataset)
db.session.commit()
click.echo(f'Successfully migrated dataset {dataset.id}.')
create_count += 1
except Exception as e:
db.session.rollback()
@@ -237,7 +365,9 @@ def vdb_migrate():
fg='red'))
continue
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
click.echo(
click.style(f'Congratulations! Create {create_count} dataset indexes, and skipped {skipped_count} datasets.',
fg='green'))
def register_commands(app):
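Two details of this command rewrite are easy to miss. First, the migration is now scoped: with the new `--scope` option the knowledge and annotation migrations can be run separately (presumably `flask vdb-migrate --scope annotation`, assuming the command is still invoked through the Flask CLI). Second, the inline collection-name expression used for Weaviate, Qdrant and Milvus is replaced by `Dataset.gen_collection_name_by_id`. A hypothetical reconstruction of that helper, derived only from the expression it replaces here (the real implementation lives in the dataset model and may differ):

```python
# Sketch of Dataset.gen_collection_name_by_id, inferred from the inline
# string it replaces in this diff. Not the actual model code.
class Dataset:
    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized = dataset_id.replace("-", "_")
        return f"Vector_index_{normalized}_Node"
```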

View File

@@ -90,7 +90,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.5.7"
self.CURRENT_VERSION = "0.5.9"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -293,6 +293,8 @@ class Config:
self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')
self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
class CloudEditionConfig(Config):

View File

@@ -13,30 +13,14 @@ model_templates = {
'status': 'normal'
},
'model_config': {
'provider': 'openai',
'model_id': 'gpt-3.5-turbo-instruct',
'configs': {
'prompt_template': '',
'prompt_variables': [],
'completion_params': {
'max_token': 512,
'temperature': 1,
'top_p': 1,
'presence_penalty': 0,
'frequency_penalty': 0,
}
},
'provider': '',
'model_id': '',
'configs': {},
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
"top_p": 1,
"presence_penalty": 0,
"frequency_penalty": 0
}
"completion_params": {}
}),
'user_input_form': json.dumps([
{
@@ -64,30 +48,14 @@ model_templates = {
'status': 'normal'
},
'model_config': {
'provider': 'openai',
'model_id': 'gpt-3.5-turbo',
'configs': {
'prompt_template': '',
'prompt_variables': [],
'completion_params': {
'max_token': 512,
'temperature': 1,
'top_p': 1,
'presence_penalty': 0,
'frequency_penalty': 0,
}
},
'provider': '',
'model_id': '',
'configs': {},
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
"top_p": 1,
"presence_penalty": 0,
"frequency_penalty": 0
}
"completion_params": {}
})
}
},

View File

@@ -27,7 +27,9 @@ from fields.app_fields import (
from libs.login import login_required
from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager
from core.entities.application_entities import AgentToolEntity
def _get_app(app_id, tenant_id):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
@@ -129,7 +131,7 @@ class AppListApi(Resource):
"No Default System Reasoning Model available. Please configure "
"in the Settings -> Model Provider.")
else:
model_config_dict["model"]["provider"] = default_model_entity.provider
model_config_dict["model"]["provider"] = default_model_entity.provider.provider
model_config_dict["model"]["name"] = default_model_entity.model
model_configuration = AppModelConfigService.validate_configuration(
@@ -236,7 +238,42 @@ class AppApi(Resource):
def get(self, app_id):
"""Get app detail"""
app_id = str(app_id)
app = _get_app(app_id, current_user.current_tenant_id)
app: App = _get_app(app_id, current_user.current_tenant_id)
# get original app model config
model_config: AppModelConfig = app.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
except Exception as e:
pass
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
return app

View File

@@ -88,7 +88,7 @@ class ChatMessageTextApi(Resource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
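This change, repeated for the explore, service API and web audio endpoints later in this diff, lets a per-request `voice` field override the voice configured in the app's text-to-speech settings, falling back to the configured voice when the field is empty. A condensed illustration of that precedence with hypothetical names (the controllers themselves read `request.form` directly, as shown above):

```python
from typing import Optional

# Illustration only: prefer the voice supplied with the request,
# otherwise fall back to the app's configured TTS voice.
def resolve_voice(requested: Optional[str], configured: Optional[str]) -> Optional[str]:
    return requested if requested else configured
```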

View File

@@ -1,3 +1,4 @@
import json
from flask import request
from flask_login import current_user
@@ -7,6 +8,9 @@ from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.entities.application_entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
@@ -38,6 +42,88 @@ class ModelConfigResource(Resource):
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
# get original app model config
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app.app_model_config_id
).first()
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
except Exception as e:
continue
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
except Exception as e:
continue
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
manager.delete_tool_parameters_cache()
# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
db.session.add(new_app_model_config)
db.session.flush()
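Together with the app-detail change a few files above (where secret tool parameters are decrypted and then masked before being returned), this block completes a masked-secret round trip for agent tools: when the config is saved again, any parameter that still equals its masked form is swapped back to the previously stored plaintext before re-encryption, so clients can echo masked values without wiping the secret. A minimal sketch of that idea with hypothetical names (the real flow goes through `ToolParameterConfigurationManager`, as shown above):

```python
# Conceptual sketch of the masked-secret round trip; names are illustrative.
def merge_secret(submitted: str, masked: str, stored_plaintext: str) -> str:
    # The client only ever saw the masked value; if it sends that back
    # unchanged, keep the stored secret instead of overwriting it.
    return stored_plaintext if submitted == masked else submitted
```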

View File

@@ -11,7 +11,7 @@ from controllers.console.datasets.error import (
UnsupportedFileTypeError,
)
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS, FileService
@@ -39,6 +39,7 @@ class FileApi(Resource):
@login_required
@account_initialization_required
@marshal_with(file_fields)
@cloud_edition_billing_resource_check(resource='documents')
def post(self):
# get file from request

View File

@@ -85,7 +85,7 @@ class ChatTextApi(InstalledAppResource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=request.form['text'],
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)
return {'data': response.data.decode('latin1')}

View File

@@ -82,6 +82,30 @@ class ToolBuiltinProviderIconApi(Resource):
icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=minetype)
class ToolModelProviderIconApi(Resource):
@setup_required
def get(self, provider):
icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)
class ToolModelProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
args = parser.parse_args()
return ToolManageService.list_model_tool_provider_tools(
user_id,
tenant_id,
args['provider'],
)
class ToolApiProviderAddApi(Resource):
@setup_required
@@ -259,6 +283,7 @@ class ToolApiProviderPreviousTestApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json')
parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json')
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json')
parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json')
@@ -268,6 +293,7 @@ class ToolApiProviderPreviousTestApi(Resource):
return ToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
args['provider_name'] if args['provider_name'] else '',
args['tool_name'],
args['credentials'],
args['parameters'],
@@ -281,6 +307,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
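The two new routes expose model providers as tool providers in the console API: one serves the provider icon, the other lists the provider's tools, taking the provider name as a query argument. A rough usage sketch against a self-hosted instance (the `/console/api` prefix is the usual mount point for these resources; the base URL is a placeholder, and the authentication that the console API requires is omitted here):

```python
import requests

# Hypothetical request; an unauthenticated call will be rejected.
BASE = "http://localhost:5001/console/api"
resp = requests.get(
    f"{BASE}/workspaces/current/tool-provider/model/tools",
    params={"provider": "openai"},
)
print(resp.status_code)
```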

View File

@@ -56,6 +56,7 @@ def cloud_edition_billing_resource_check(resource: str,
members = features.members
apps = features.apps
vector_space = features.vector_space
documents_upload_quota = features.documents_upload_quota
annotation_quota_limit = features.annotation_quota_limit
if resource == 'members' and 0 < members.limit <= members.size:
@@ -64,6 +65,13 @@
abort(403, error_msg)
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
abort(403, error_msg)
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
source = request.args.get('source')
if source == 'datasets':
abort(403, error_msg)
else:
return view(*args, **kwargs)
elif resource == 'workspace_custom' and not features.can_replace_logo:
abort(403, error_msg)
elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
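The new `documents` branch only enforces the upload quota when the request marks itself as coming from the datasets area (`source=datasets` in the query string), because the same file-upload endpoint is shared with other features. A reduced illustration of that gate with hypothetical names, not the decorator itself:

```python
from typing import Optional

# Simplified view of the new quota branch: only dataset-originated uploads
# are counted against the documents quota.
def should_block_upload(limit: int, used: int, source: Optional[str]) -> bool:
    quota_exhausted = 0 < limit <= used
    return quota_exhausted and source == 'datasets'
```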

View File

@@ -87,7 +87,7 @@ class TextApi(Resource):
tenant_id=app_model.tenant_id,
text=args['text'],
end_user=end_user,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
voice=args['voice'] if args['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming']
)

View File

@@ -28,6 +28,7 @@ class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
@cloud_edition_billing_resource_check('documents', 'dataset')
def post(self, tenant_id, dataset_id):
"""Create document by text."""
parser = reqparse.RequestParser()
@@ -153,6 +154,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
@cloud_edition_billing_resource_check('vector_space', 'dataset')
@cloud_edition_billing_resource_check('documents', 'dataset')
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
args = {}

View File

@@ -200,8 +200,8 @@ class DatasetSegmentApi(DatasetApiResource):
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args['segments'], document)
segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form

View File

@@ -89,6 +89,7 @@ def cloud_edition_billing_resource_check(resource: str,
members = features.members
apps = features.apps
vector_space = features.vector_space
documents_upload_quota = features.documents_upload_quota
if resource == 'members' and 0 < members.limit <= members.size:
raise Unauthorized(error_msg)
@@ -96,6 +97,8 @@
raise Unauthorized(error_msg)
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
raise Unauthorized(error_msg)
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
raise Unauthorized(error_msg)
else:
return view(*args, **kwargs)

View File

@@ -84,7 +84,7 @@ class TextApi(WebApiResource):
tenant_id=app_model.tenant_id,
text=request.form['text'],
end_user=end_user.external_user_id,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
voice=request.form['voice'] if request.form['voice'] else app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=False
)

View File

@@ -1,49 +0,0 @@
from typing import cast
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class CalcTokenMixin:
def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
"""
Got the rest tokens available for the model after excluding messages tokens and completion max tokens
:param model_config:
:param messages:
:return:
"""
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return 0
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_type_instance.get_num_tokens(
model_config.model,
model_config.credentials,
messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
return rest_tokens
class ExceededLLMTokensLimitError(Exception):
pass

View File

@@ -1,361 +0,0 @@
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
)
from langchain.tools import BaseTool
from pydantic import root_validator
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_config: ModelConfigEntity = None
model_config: ModelConfigEntity
agent_llm_callback: Optional[AgentLLMCallback] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> BaseSingleActionAgent:
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_config=model_config,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
**kwargs,
)
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
original_max_tokens = 0
for parameter_rule in self.model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
or self.model_config.parameters.get(parameter_rule.use_template)) or 0
self.model_config.parameters['max_tokens'] = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
prompt_messages = lc_messages_to_prompt_messages(messages)
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
except Exception as e:
raise e
self.model_config.parameters['max_tokens'] = original_max_tokens
return True if result.message.tool_calls else False
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
# summarize messages if rest_tokens < 0
try:
prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
model_instance = ModelInstance(
provider_model_bundle=self.model_config.provider_model_bundle,
model=self.model_config.model,
)
tools = []
for function in self.functions:
tool = PromptMessageTool(
**function
)
tools.append(tool)
result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=tools,
stream=False,
callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
model_parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
ai_message = AIMessage(
content=result.message.content or "",
additional_kwargs={
'function_call': {
'id': result.message.tool_calls[0].id,
**result.message.tool_calls[0].function.dict()
} if result.message.tool_calls else None
}
)
agent_decision = _parse_ai_message(ai_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
@classmethod
def get_system_message(cls):
return SystemMessage(content="You are a helpful AI assistant.\n"
"The current date or current time you know is wrong.\n"
"Respond directly if appropriate.")
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
try:
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(
self.model_config,
messages,
**kwargs
)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def predict_new_summary(
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_config.provider == 'azure_openai':
model = model_config.model
model = model.replace("gpt-35", "gpt-3.5")
else:
model = model_config.credentials.get("base_model_name")
tiktoken_ = _import_tiktoken()
try:
encoding = tiktoken_.encoding_for_model(model)
except KeyError:
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens

View File

@@ -1,306 +0,0 @@
import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
AIMessage,
BaseMessage,
HumanMessage,
OutputParserException,
get_buffer_string,
)
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
"action": "Final Answer",
"action_input": "Final response to human"
}}}}
```"""
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_config: ModelConfigEntity = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def should_use_agent(self, query: str):
"""
return should use agent
Using the ReACT mode to determine whether an agent is needed is costly,
so it's better to just use an Agent for reasoning, which is cheaper.
:param query:
:return:
"""
return True
def plan(
self,
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observatons
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
messages = []
if prompts:
messages = prompts[0].to_messages()
prompt_messages = lc_messages_to_prompt_messages(messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
raise e
try:
agent_decision = self.output_parser.parse(full_output)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_model_config:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
self.moving_summary_index = len(intermediate_steps)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
if 'chat_history' in kwargs:
kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: list[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
formatted_tools = "\n".join(tool_strings)
tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
_memory_prompts = memory_prompts or []
messages = [
SystemMessagePromptTemplate.from_template(template),
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def create_completion_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
agent_scratchpad += action.log
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_config.mode == "chat":
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
else:
return agent_scratchpad
@classmethod
def from_llm_and_tools(
cls,
model_config: ModelConfigEntity,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
agent_llm_callback: Optional[AgentLLMCallback] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_config.mode == "chat":
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
else:
prompt = cls.create_completion_prompt(
tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
model_config=model_config,
prompt=prompt,
callback_manager=callback_manager,
agent_llm_callback=agent_llm_callback,
parameters={
'temperature': 0.2,
'top_p': 0.3,
'max_tokens': 1500
}
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)
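
The plan/summarize_messages flow above amounts to a rolling-summary fallback: when the assembled prompt exceeds the model's remaining context, everything except the latest intermediate step is condensed into a moving summary and the prompt is rebuilt. A minimal self-contained sketch of that idea follows; count_tokens and summarize here are trivial stand-ins, not the repository's helpers.

# Sketch only: rolling-summary fallback when the prompt exceeds the context window.
# `count_tokens` and `summarize` are hypothetical stand-ins for model-specific helpers.
def build_prompt(query: str, summary: str, steps: list[tuple[str, str]],
                 context_limit: int, count_tokens, summarize) -> tuple[str, str, list]:
    def render(current_summary, current_steps):
        scratchpad = "\n".join(f"{log}\nObservation: {obs}" for log, obs in current_steps)
        return f"{current_summary}\n{query}\n{scratchpad}".strip()

    prompt = render(summary, steps)
    if count_tokens(prompt) > context_limit and len(steps) >= 2:
        # fold everything except the latest step into the rolling summary
        older, latest = steps[:-1], steps[-1:]
        summary = summarize(summary, "\n".join(obs for _, obs in older))
        steps = latest
        prompt = render(summary, steps)
    return prompt, summary, steps

# toy usage with trivial stand-ins
prompt, summary, steps = build_prompt(
    "What is Dify?", "",
    [("search('dify')", "Dify is an LLM app platform."),
     ("open_docs()", "Docs describe datasets and tools.")],
    context_limit=10,
    count_tokens=lambda text: len(text.split()),
    summarize=lambda existing, new: (existing + " " + new).strip(),
)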

View File

@ -84,7 +84,7 @@ class AppRunner:
return rest_tokens
def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
def recalc_llm_max_tokens(self, model_config: ModelConfigEntity,
prompt_messages: list[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_type_instance = model_config.provider_model_bundle.model_type_instance

View File

@ -1,4 +1,3 @@
import json
import logging
from typing import cast
@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
message_chain = self._init_message_chain(
message=message,
query=query
)
# init model instance
model_instance = ModelInstance(
@ -201,6 +195,10 @@ class AssistantApplicationRunner(AppRunner):
if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
db.session.refresh(conversation)
db.session.refresh(message)
db.session.close()
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
assistant_cot_runner = AssistantCotApplicationRunner(
@ -290,38 +288,6 @@ class AssistantApplicationRunner(AppRunner):
'pool': db_variables.variables
})
def _init_message_chain(self, message: Message, query: str) -> MessageChain:
"""
Init MessageChain
:param message: message
:param query: query
:return:
"""
message_chain = MessageChain(
message_id=message.id,
type="AgentExecutor",
input=json.dumps({
"input": query
})
)
db.session.add(message_chain)
db.session.commit()
return message_chain
def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
"""
Save MessageChain
:param message_chain: message chain
:param output_text: output text
:return:
"""
message_chain.output = json.dumps({
"output": output_text
})
db.session.commit()
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
message: Message) -> LLMUsage:
"""

View File

@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
@ -181,7 +181,7 @@ class BasicApplicationRunner(AppRunner):
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recale_llm_max_tokens(
self.recalc_llm_max_tokens(
model_config=app_orchestration_config.model_config,
prompt_messages=prompt_messages
)
@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner):
model=app_orchestration_config.model_config.model
)
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_orchestration_config.model_config.parameters,

View File

@ -89,6 +89,10 @@ class GenerateTaskPipeline:
Process generate task pipeline.
:return:
"""
db.session.refresh(self._conversation)
db.session.refresh(self._message)
db.session.close()
if stream:
return self._process_stream_response()
else:
@ -303,6 +307,7 @@ class GenerateTaskPipeline:
.first()
)
db.session.refresh(agent_thought)
db.session.close()
if agent_thought:
response = {
@ -330,6 +335,8 @@ class GenerateTaskPipeline:
.filter(MessageFile.id == event.message_file_id)
.first()
)
db.session.close()
# get extension
if '.' in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}'
@ -413,6 +420,7 @@ class GenerateTaskPipeline:
usage = llm_result.usage
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
self._message.message_tokens = usage.prompt_tokens
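
Several hunks in this change set follow the same connection-handling pattern: refresh the ORM rows that are still needed, close the session so the pooled connection is released during the slow model call, then re-query before writing results back. A minimal sketch of that pattern; db, Message and invoke_llm are stand-ins for the repository's Flask-SQLAlchemy handle, model and invocation call.

# Sketch: release the DB connection before a long-running LLM call, re-query afterwards.
# `db`, `Message` and `invoke_llm` are stand-ins, not the repository's own objects.
def generate_answer(db, Message, message_id: str, invoke_llm) -> None:
    message = db.session.query(Message).filter(Message.id == message_id).first()
    db.session.refresh(message)          # make sure the in-memory row is current
    query_text = message.query           # copy what is needed before detaching
    db.session.close()                   # return the connection to the pool

    answer = invoke_llm(query_text)      # potentially takes many seconds

    # re-query before writing results back, since the old instance is detached
    message = db.session.query(Message).filter(Message.id == message_id).first()
    message.answer = answer
    db.session.commit()
    db.session.close()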

View File

@ -201,7 +201,7 @@ class ApplicationManager:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
queue_manager: ApplicationQueueManager,
@ -233,8 +233,6 @@ class ApplicationManager:
else:
logger.exception(e)
raise e
finally:
db.session.remove()
def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
-> AppOrchestrationConfigEntity:
@ -651,6 +649,7 @@ class ApplicationManager:
db.session.add(conversation)
db.session.commit()
db.session.refresh(conversation)
else:
conversation = (
db.session.query(Conversation)
@ -689,6 +688,7 @@ class ApplicationManager:
db.session.add(message)
db.session.commit()
db.session.refresh(message)
for file in application_generate_entity.files:
message_file = MessageFile(

View File

@ -0,0 +1,8 @@
from enum import Enum
class PlanningStrategy(Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
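
This enum is consumed elsewhere in the change set, where the runner prefers function calling whenever the model schema advertises tool-call support and otherwise falls back to ReAct. A small sketch of that selection; the ModelFeature values here are simplified stand-ins for the repository's own enum.

# Sketch: choose a PlanningStrategy from a model's advertised features.
from enum import Enum

class ModelFeature(Enum):          # simplified stand-in; value strings are assumptions
    TOOL_CALL = 'tool-call'
    MULTI_TOOL_CALL = 'multi-tool-call'
    VISION = 'vision'

class PlanningStrategy(Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'

def choose_strategy(features: list[ModelFeature]) -> PlanningStrategy:
    # prefer native function calling when the model supports tool calls
    if {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} & set(features):
        return PlanningStrategy.FUNCTION_CALL
    return PlanningStrategy.REACT

assert choose_strategy([ModelFeature.TOOL_CALL]) is PlanningStrategy.FUNCTION_CALL
assert choose_strategy([ModelFeature.VISION]) is PlanningStrategy.REACT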

View File

@ -1,199 +0,0 @@
import logging
from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
AgentEntity,
AppOrchestrationConfigEntity,
InvokeFrom,
ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message
logger = logging.getLogger(__name__)
class AgentRunnerFeature:
def __init__(self, tenant_id: str,
app_orchestration_config: AppOrchestrationConfigEntity,
model_config: ModelConfigEntity,
config: AgentEntity,
queue_manager: ApplicationQueueManager,
message: Message,
user_id: str,
agent_llm_callback: AgentLLMCallback,
callback: AgentLoopGatherCallbackHandler,
memory: Optional[TokenBufferMemory] = None,) -> None:
"""
Agent runner
:param tenant_id: tenant id
:param app_orchestration_config: app orchestration config
:param model_config: model config
:param config: dataset config
:param queue_manager: queue manager
:param message: message
:param user_id: user id
:param agent_llm_callback: agent llm callback
:param callback: callback
:param memory: memory
"""
self.tenant_id = tenant_id
self.app_orchestration_config = app_orchestration_config
self.model_config = model_config
self.config = config
self.queue_manager = queue_manager
self.message = message
self.user_id = user_id
self.agent_llm_callback = agent_llm_callback
self.callback = callback
self.memory = memory
def run(self, query: str,
invoke_from: InvokeFrom) -> Optional[str]:
"""
Retrieve agent loop result.
:param query: query
:param invoke_from: invoke from
:return:
"""
provider = self.config.provider
model = self.config.model
tool_configs = self.config.tools
# check model is support tool calling
provider_instance = model_provider_factory.get_provider_instance(provider=provider)
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model,
credentials=self.model_config.credentials
)
if not model_schema:
return None
planning_strategy = PlanningStrategy.REACT
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.FUNCTION_CALL
tools = self.to_tools(
tool_configs=tool_configs,
invoke_from=invoke_from,
callbacks=[self.callback, DifyStdOutCallbackHandler()],
)
if len(tools) == 0:
return None
agent_configuration = AgentConfiguration(
strategy=planning_strategy,
model_config=self.model_config,
tools=tools,
memory=self.memory,
max_iterations=10,
max_execution_time=400.0,
early_stopping_method="generate",
agent_llm_callback=self.agent_llm_callback,
callbacks=[self.callback, DifyStdOutCallbackHandler()]
)
agent_executor = AgentExecutor(agent_configuration)
try:
# check if should use agent
should_use_agent = agent_executor.should_use_agent(query)
if not should_use_agent:
return None
result = agent_executor.run(query)
return result.output
except Exception as ex:
logger.exception("agent_executor run failed")
return None
def to_dataset_retriever_tool(self, tool_config: dict,
invoke_from: InvokeFrom) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tool_config: tool config
:param invoke_from: invoke from
"""
show_retrieve_source = self.app_orchestration_config.show_retrieve_source
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=self.queue_manager,
app_id=self.message.app_id,
message_id=self.message.id,
user_id=self.user_id,
invoke_from=invoke_from
)
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == self.tenant_id,
Dataset.id == tool_config.get("id")
).first()
# pass if dataset is not available
if not dataset:
return None
# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
return None
# get retrieval model config
default_retrieval_model = {
'search_method': 'semantic_search',
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
}
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
# get score threshold
score_threshold = None
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
top_k=top_k,
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=show_retrieve_source,
retriever_from=invoke_from.to_source()
)
return tool

View File

@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner):
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id,
).count()
db.session.close()
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
@ -154,9 +155,9 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
convert tool to prompt message tool
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tenant_id=self.application_generate_entity.tenant_id,
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
agent_tool=tool,
agent_callback=self.agent_callback
)
tool_entity.load_variables(self.variables_pool)
@ -171,33 +172,11 @@ class BaseAssistantApplicationRunner(AppRunner):
}
)
runtime_parameters = {}
parameters = tool_entity.parameters or []
user_parameters = tool_entity.get_runtime_parameters() or []
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
@ -213,59 +192,16 @@ class BaseAssistantApplicationRunner(AppRunner):
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config
elif parameter.form == ToolParameter.ToolParameterForm.LLM:
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
message_tool.parameters['required'].append(parameter.name)
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
if parameter.required:
message_tool.parameters['required'].append(parameter.name)
return message_tool, tool_entity
@ -305,6 +241,9 @@ class BaseAssistantApplicationRunner(AppRunner):
tool_runtime_parameters = tool.get_runtime_parameters() or []
for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
@ -320,18 +259,17 @@ class BaseAssistantApplicationRunner(AppRunner):
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParameter.ToolParameterForm.LLM:
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool
@ -404,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner):
created_by=self.user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
result.append((
message_file,
message.save_as
))
db.session.commit()
db.session.close()
return result
def create_agent_thought(self, message_id: str, message: str,
@ -447,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner):
db.session.add(thought)
db.session.commit()
db.session.refresh(thought)
db.session.close()
self.agent_thought_count += 1
@ -464,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()
if thought is not None:
agent_thought.thought = thought
@ -514,6 +461,7 @@ class BaseAssistantApplicationRunner(AppRunner):
agent_thought.tool_labels_str = json.dumps(labels)
db.session.commit()
db.session.close()
def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
"""
@ -586,9 +534,14 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
convert tool variables to db variables
"""
db_variables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
).first()
db_variables.updated_at = datetime.utcnow()
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit()
db.session.close()
def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
@ -644,4 +597,6 @@ class BaseAssistantApplicationRunner(AppRunner):
if message.answer:
result.append(AssistantPromptMessage(content=message.answer))
db.session.close()
return result
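
The parameter-handling hunks above all reduce to the same mapping: each LLM-form tool parameter becomes a JSON-schema style property (type, description, optional enum) on the prompt message tool, plus an entry in its required list. A compact sketch with plain dictionaries standing in for the repository's ToolParameter entities:

# Sketch: map tool parameters onto a JSON-schema-like tool definition.
# The parameter dicts below are simplified stand-ins for ToolParameter entities.
def build_tool_schema(name: str, description: str, parameters: list[dict]) -> dict:
    schema = {
        "name": name,
        "description": description,
        "parameters": {"type": "object", "properties": {}, "required": []},
    }
    for p in parameters:
        if p.get("form") != "llm":          # only LLM-filled parameters are exposed
            continue
        prop = {"type": p.get("type", "string"), "description": p.get("llm_description", "")}
        if p.get("options"):
            prop["enum"] = p["options"]
        schema["parameters"]["properties"][p["name"]] = prop
        if p.get("required"):
            schema["parameters"]["required"].append(p["name"])
    return schema

print(build_tool_schema(
    "weather", "Look up the weather",
    [{"name": "city", "form": "llm", "type": "string",
      "llm_description": "City name", "required": True}],
))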

View File

@ -28,6 +28,9 @@ from models.model import Conversation, Message
class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
_is_first_iteration = True
_ignore_observation_providers = ['wenxin']
def run(self, conversation: Conversation,
message: Message,
query: str,
@ -42,10 +45,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
agent_scratchpad: list[AgentScratchpadUnit] = []
self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
# check model mode
if self.app_orchestration_config.model_config.mode == "completion":
# TODO: stop words
if 'Observation' not in app_orchestration_config.model_config.stop:
if 'Observation' not in app_orchestration_config.model_config.stop:
if app_orchestration_config.model_config.provider not in self._ignore_observation_providers:
app_orchestration_config.model_config.stop.append('Observation')
# override inputs
@ -130,8 +131,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
input=query
)
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# recalc llm max tokens
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
@ -202,6 +203,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
)
)
scratchpad.thought = scratchpad.thought.strip() or 'I am thinking about how to help you'
agent_scratchpad.append(scratchpad)
# get llm usage
@ -255,9 +257,15 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
# invoke tool
error_response = None
try:
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass
tool_response = tool_instance.invoke(
user_id=self.user_id,
tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
tool_parameters=tool_call_args
)
# transform tool response to llm friendly response
tool_response = self.transform_tool_invoke_messages(tool_response)
@ -466,7 +474,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
if isinstance(message, AssistantPromptMessage):
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content,
thought=message.content or 'I am thinking about how to help you',
action_str='',
action=None,
observation=None,
@ -546,7 +554,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
result = ''
for scratchpad in agent_scratchpad:
result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
result += (scratchpad.thought or '') + (scratchpad.action_str or '') + \
next_iteration.replace("{{observation}}", scratchpad.observation or 'It seems that no response is available')
return result
@ -621,21 +630,24 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
))
# add assistant message
if len(agent_scratchpad) > 0:
if len(agent_scratchpad) > 0 and not self._is_first_iteration:
prompt_messages.append(AssistantPromptMessage(
content=(agent_scratchpad[-1].thought or '')
content=(agent_scratchpad[-1].thought or '') + (agent_scratchpad[-1].action_str or ''),
))
# add user message
if len(agent_scratchpad) > 0:
if len(agent_scratchpad) > 0 and not self._is_first_iteration:
prompt_messages.append(UserPromptMessage(
content=(agent_scratchpad[-1].observation or ''),
content=(agent_scratchpad[-1].observation or 'It seems that no response is available'),
))
self._is_first_iteration = False
return prompt_messages
elif mode == "completion":
# parse agent scratchpad
agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
self._is_first_iteration = False
# parse prompt messages
return [UserPromptMessage(
content=first_prompt.replace("{{instruction}}", instruction)
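
One small but useful change in this file is the tolerant handling of tool arguments: the model may emit them as a dict or as a JSON string, and a plain string that is not JSON should be passed through rather than raising. A short sketch of that pattern:

# Sketch: accept tool arguments as a dict, a JSON string, or a plain string.
import json
from typing import Any

def normalize_tool_args(raw: Any) -> Any:
    if isinstance(raw, str):
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return raw  # fall back to the raw string, e.g. a single free-text argument
    return raw

assert normalize_tool_args('{"query": "dify"}') == {"query": "dify"}
assert normalize_tool_args('plain text') == 'plain text'
assert normalize_tool_args({"query": "dify"}) == {"query": "dify"}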

View File

@ -105,8 +105,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
messages_ids=message_file_ids
)
# recale llm max tokens
self.recale_llm_max_tokens(self.model_config, prompt_messages)
# recalc llm max tokens
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,

View File

@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):

View File

@ -12,9 +12,9 @@ from pydantic import root_validator
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):

View File

@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.features.dataset_retrieval.agent.llm_chain import LLMChain
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.

View File

@ -1,4 +1,3 @@
import enum
import logging
from typing import Optional, Union
@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra
from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
class PlanningStrategy(str, enum.Enum):
ROUTER = 'router'
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
class AgentConfiguration(BaseModel):
strategy: PlanningStrategy
model_config: ModelConfigEntity
@ -62,28 +53,7 @@ class AgentExecutor:
self.agent = self._init_agent()
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_config=self.configuration.model_config,
tools=self.configuration.tools,
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
if self.configuration.memory else None, # used for read chat histories memory
summary_model_config=self.configuration.summary_model_config
if self.configuration.summary_model_config else None,
agent_llm_callback=self.configuration.agent_llm_callback,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.ROUTER:
if self.configuration.strategy == PlanningStrategy.ROUTER:
self.configuration.tools = [t for t in self.configuration.tools
if isinstance(t, DatasetRetrieverTool)
or isinstance(t, DatasetMultiRetrieverTool)]

View File

@ -2,9 +2,10 @@ from typing import Optional, cast
from langchain.tools import BaseTool
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

View File

@ -0,0 +1,54 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolParameterCacheType(Enum):
PARAMETER = "tool_parameter"
class ToolParameterCache:
def __init__(self,
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
def get(self) -> Optional[dict]:
"""
Get cached tool parameters.
:return:
"""
cached_tool_parameter = redis_client.get(self.cache_key)
if cached_tool_parameter:
try:
cached_tool_parameter = cached_tool_parameter.decode('utf-8')
cached_tool_parameter = json.loads(cached_tool_parameter)
except JSONDecodeError:
return None
return cached_tool_parameter
else:
return None
def set(self, parameters: dict) -> None:
"""
Cache tool parameters.
:param parameters: tool parameters
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
def delete(self) -> None:
"""
Delete cached tool parameters.
:return:
"""
redis_client.delete(self.cache_key)
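
A read-through usage of this cache might look like the sketch below; the tenant, provider and tool names are placeholders, and it assumes the class above plus a configured redis_client are in scope.

# Sketch: read-through use of the cache above (all identifiers are placeholders).
cache = ToolParameterCache(
    tenant_id="tenant-123",
    provider="dingtalk",
    tool_name="send_message",
    cache_type=ToolParameterCacheType.PARAMETER,
)

parameters = cache.get()
if parameters is None:
    parameters = {"webhook_key": "***"}   # normally decrypted from the stored tool config
    cache.set(parameters)                  # cached for 86400 seconds (one day)

cache.delete()                             # invalidate when the tool configuration changes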

View File

@ -82,6 +82,8 @@ class HostingConfiguration:
RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM),
RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM),
RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING),
RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING),
RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING),
]
)
quotas.append(trial_quota)

View File

@ -62,7 +62,8 @@ class IndexingRunner:
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
# transform
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
processing_rule.to_dict())
# save segment
self._load_segments(dataset, dataset_document, documents)
@ -120,7 +121,8 @@ class IndexingRunner:
text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
# transform
documents = self._transform(index_processor, dataset, text_docs, processing_rule.to_dict())
documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
processing_rule.to_dict())
# save segment
self._load_segments(dataset, dataset_document, documents)
@ -186,7 +188,7 @@ class IndexingRunner:
first()
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type, processing_rule.to_dict()).init_index_processor()
index_processor = IndexProcessorFactory(index_type).init_index_processor()
self._load(
index_processor=index_processor,
dataset=dataset,
@ -414,9 +416,14 @@ class IndexingRunner:
if separator:
separator = separator.replace('\\n', '\n')
if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']:
chunk_overlap = segmentation['chunk_overlap']
else:
chunk_overlap = 0
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=segmentation.get('chunk_overlap', 0),
chunk_overlap=chunk_overlap,
fixed_separator=separator,
separators=["\n\n", "", ".", " ", ""],
embedding_model_instance=embedding_model_instance
@ -750,7 +757,7 @@ class IndexingRunner:
index_processor.load(dataset, documents)
def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
text_docs: list[Document], process_rule: dict) -> list[Document]:
text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]:
# get embedding model instance
embedding_model_instance = None
if dataset.indexing_technique == 'high_quality':
@ -768,7 +775,8 @@ class IndexingRunner:
)
documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
process_rule=process_rule)
process_rule=process_rule, tenant_id=dataset.tenant_id,
doc_language=doc_language)
return documents
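
The splitter change above guards against a chunk_overlap key that is present in the rule but set to None, which the previous segmentation.get('chunk_overlap', 0) would have passed straight to the splitter. A tiny sketch of the corrected defaulting:

# Sketch: resolve chunk_overlap so the splitter never receives None.
def resolve_chunk_overlap(segmentation: dict) -> int:
    if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']:
        return segmentation['chunk_overlap']
    return 0

assert resolve_chunk_overlap({'max_tokens': 500, 'chunk_overlap': 50}) == 50
assert resolve_chunk_overlap({'max_tokens': 500, 'chunk_overlap': None}) == 0
assert resolve_chunk_overlap({'max_tokens': 500}) == 0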

View File

@ -47,11 +47,14 @@ class TokenBufferMemory:
files, message.app_model_config
)
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query))
else:
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
for file_obj in file_objs:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=message.query))

View File

@ -17,7 +17,7 @@ class ModelType(Enum):
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
TTS = "tts"
# TEXT2IMG = "text2img"
TEXT2IMG = "text2img"
@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
@ -36,6 +36,8 @@ class ModelType(Enum):
return cls.SPEECH2TEXT
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
return cls.TTS
elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION
else:
@ -59,10 +61,11 @@ class ModelType(Enum):
return 'tts'
elif self == self.MODERATION:
return 'moderation'
elif self == self.TEXT2IMG:
return 'text2img'
else:
raise ValueError(f'invalid model type {self}')
class FetchFrom(Enum):
"""
Enum class for fetch from.

View File

@ -0,0 +1,48 @@
from abc import abstractmethod
from typing import IO, Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
class Text2ImageModel(AIModel):
"""
Model class for text2img model.
"""
model_type: ModelType = ModelType.TEXT2IMG
def invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: image bytes
"""
try:
return self._invoke(model, credentials, prompt, model_parameters, user)
except Exception as e:
raise self._transform_invoke_error(e)
@abstractmethod
def _invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: image bytes
"""
raise NotImplementedError
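
A concrete provider would subclass this and implement only _invoke; the public invoke wrapper keeps the error translation. The sketch below is hypothetical: the endpoint, credential key and response shape are invented for illustration.

# Hypothetical Text2ImageModel subclass; provider endpoint and payload are invented.
import base64
import io
from typing import IO, Optional

import requests

class ExampleText2ImageModel(Text2ImageModel):
    def _invoke(self, model: str, credentials: dict, prompt: str,
                model_parameters: dict, user: Optional[str] = None) -> list[IO[bytes]]:
        response = requests.post(
            "https://example.invalid/v1/images",          # placeholder endpoint
            headers={"Authorization": f"Bearer {credentials['api_key']}"},
            json={"model": model, "prompt": prompt, **model_parameters},
            timeout=60,
        )
        response.raise_for_status()
        # assume the endpoint returns base64-encoded images; wrap each as a byte stream
        return [io.BytesIO(base64.b64decode(item["b64"])) for item in response.json()["images"]]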

View File

@ -7,6 +7,7 @@
- togetherai
- ollama
- mistralai
- groq
- replicate
- huggingface_hub
- zhipuai

View File

@ -21,7 +21,7 @@ class AnthropicProvider(ModelProvider):
# Use `claude-instant-1` model for validation,
model_instance.validate_credentials(
model='claude-instant-1',
model='claude-instant-1.2',
credentials=credentials
)
except CredentialsValidateFailedError as ex:

View File

@ -2,8 +2,8 @@ provider: anthropic
label:
en_US: Anthropic
description:
en_US: Anthropic's powerful models, such as Claude 2 and Claude Instant.
zh_Hans: Anthropic 的强大模型,例如 Claude 2 和 Claude Instant
en_US: Anthropic's powerful models, such as Claude 3.
zh_Hans: Anthropic 的强大模型,例如 Claude 3
icon_small:
en_US: icon_s_en.svg
icon_large:

View File

@ -0,0 +1,6 @@
- claude-3-opus-20240229
- claude-3-sonnet-20240229
- claude-2.1
- claude-instant-1.2
- claude-2
- claude-instant-1

View File

@ -34,3 +34,4 @@ pricing:
output: '24.00'
unit: '0.000001'
currency: USD
deprecated: true

View File

@ -0,0 +1,37 @@
model: claude-3-opus-20240229
label:
en_US: claude-3-opus-20240229
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '15.00'
output: '75.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,37 @@
model: claude-3-sonnet-20240229
label:
en_US: claude-3-sonnet-20240229
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '3.00'
output: '15.00'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,35 @@
model: claude-instant-1.2
label:
en_US: claude-instant-1.2
model_type: llm
features: [ ]
model_properties:
mode: chat
context_size: 100000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 4096
min: 1
max: 4096
- name: response_format
use_template: response_format
pricing:
input: '1.63'
output: '5.51'
unit: '0.000001'
currency: USD

View File

@ -33,3 +33,4 @@ pricing:
output: '5.51'
unit: '0.000001'
currency: USD
deprecated: true

View File

@ -1,18 +1,32 @@
import base64
import mimetypes
from collections.abc import Generator
from typing import Optional, Union
from typing import Optional, Union, cast
import anthropic
import requests
from anthropic import Anthropic, Stream
from anthropic.types import Completion, completion_create_params
from anthropic.types import (
ContentBlockDeltaEvent,
Message,
MessageDeltaEvent,
MessageStartEvent,
MessageStopEvent,
MessageStreamEvent,
completion_create_params,
)
from httpx import Timeout
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
@ -35,6 +49,7 @@ if you are not sure about the structure.
</instructions>
"""
class AnthropicLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
@ -55,54 +70,114 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return: full response or stream response chunk generator result
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
return self._chat_generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
:param model: model name
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
# transform model parameters from completion api of anthropic to chat api
if 'max_tokens_to_sample' in model_parameters:
model_parameters['max_tokens'] = model_parameters.pop('max_tokens_to_sample')
# init model client
client = Anthropic(**credentials_kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
if system:
extra_model_kwargs['system'] = system
# chat model
response = client.messages.create(
model=model,
messages=prompt_message_dicts,
stream=stream,
**model_parameters,
**extra_model_kwargs
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
"""
if 'response_format' in model_parameters and model_parameters['response_format']:
stop = stop or []
self._transform_json_prompts(
model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
# chat model
self._transform_chat_json_prompts(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
response_format=model_parameters['response_format']
)
model_parameters.pop('response_format')
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def _transform_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
def _transform_chat_json_prompts(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
-> None:
"""
Transform json prompts
"""
if "```\n" not in stop:
stop.append("```\n")
if "\n```" not in stop:
stop.append("\n```")
# check if there is a system message
if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
# override the system message
prompt_messages[0] = SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
.replace("{{instructions}}", prompt_messages[0].content)
.replace("{{block}}", response_format)
)
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
else:
# insert the system message
prompt_messages.insert(0, SystemPromptMessage(
content=ANTHROPIC_BLOCK_MODE_PROMPT
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
.replace("{{instructions}}", f"Please output a valid {response_format} object.")
.replace("{{block}}", response_format)
))
prompt_messages.append(AssistantPromptMessage(
content=f"```{response_format}\n"
))
prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@ -129,7 +204,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:return:
"""
try:
self._generate(
self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=[
@ -137,58 +212,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
],
model_parameters={
"temperature": 0,
"max_tokens_to_sample": 20,
"max_tokens": 20,
},
stream=False
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Message,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Invoke large language model
:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
client = Anthropic(**credentials_kwargs)
extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop
if user:
extra_model_kwargs['metadata'] = completion_create_params.Metadata(user_id=user)
response = client.completions.create(
model=model,
prompt=self._convert_messages_to_prompt_anthropic(prompt_messages),
stream=stream,
**model_parameters,
**extra_model_kwargs
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: Completion,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm response
Handle llm chat response
:param model: model name
:param credentials: credentials
@ -198,75 +232,89 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=response.completion
content=response.content[0].text
)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if response.usage:
# transform usage
prompt_tokens = response.usage.input_tokens
completion_tokens = response.usage.output_tokens
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
response = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
usage=usage
)
return result
return response
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Stream[Completion],
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Stream[MessageStreamEvent],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
Handle llm chat stream response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
:return: llm response chunk generator
"""
index = -1
full_assistant_content = ''
return_model = None
input_tokens = 0
output_tokens = 0
finish_reason = None
index = 0
for chunk in response:
content = chunk.completion
if chunk.stop_reason is None and (content is None or content == ''):
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=content if content else '',
)
index += 1
if chunk.stop_reason is not None:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if isinstance(chunk, MessageStartEvent):
return_model = chunk.message.model
input_tokens = chunk.message.usage.input_tokens
elif isinstance(chunk, MessageDeltaEvent):
output_tokens = chunk.usage.output_tokens
finish_reason = chunk.delta.stop_reason
elif isinstance(chunk, MessageStopEvent):
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
yield LLMResultChunk(
model=chunk.model,
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
finish_reason=chunk.stop_reason,
index=index + 1,
message=AssistantPromptMessage(
content=''
),
finish_reason=finish_reason,
usage=usage
)
)
else:
elif isinstance(chunk, ContentBlockDeltaEvent):
chunk_text = chunk.delta.text if chunk.delta.text else ''
full_assistant_content += chunk_text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=chunk_text
)
index = chunk.index
yield LLMResultChunk(
model=chunk.model,
model=return_model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
index=chunk.index,
message=assistant_prompt_message,
)
)
@ -289,6 +337,80 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return credentials_kwargs
def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
"""
Convert prompt messages to dict list and system
"""
system = ""
prompt_message_dicts = []
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
system += message.content + ("\n" if not system else "")
else:
prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))
return system, prompt_message_dicts
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
"""
Convert PromptMessage to dict
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {
"type": "text",
"text": message_content.data
}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"):
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
mime_type, _ = mimetypes.guess_type(message_content.data)
base64_data = base64.b64encode(image_content).decode('utf-8')
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]:
raise ValueError(f"Unsupported image type {mime_type}, "
f"only support image/jpeg, image/png, image/gif, and image/webp")
sub_message_dict = {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data
}
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
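For reference, the dict this conversion produces for a user turn that mixes text with one base64 image follows Anthropic's Messages API content-block format. The payload below is illustrative only; the text and base64 data are made up.

example_message_dict = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/png",
                "data": "iVBORw0KGgo...",  # truncated base64 payload
            },
        },
    ],
}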
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.
@ -302,8 +424,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
if not isinstance(message.content, list):
message_text = f"{ai_prompt} {content}"
else:
message_text = ""
for sub_message in message.content:
if sub_message.type == PromptMessageContentType.TEXT:
message_text += f"{human_prompt} {sub_message.data}"
elif sub_message.type == PromptMessageContentType.IMAGE:
message_text += f"{human_prompt} [IMAGE]"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
if not isinstance(message.content, list):
message_text = f"{ai_prompt} {content}"
else:
message_text = ""
for sub_message in message.content:
if sub_message.type == PromptMessageContentType.TEXT:
message_text += f"{ai_prompt} {sub_message.data}"
elif sub_message.type == PromptMessageContentType.IMAGE:
message_text += f"{ai_prompt} [IMAGE]"
elif isinstance(message, SystemPromptMessage):
message_text = content
else:

View File

@ -524,5 +524,172 @@ EMBEDDING_BASE_MODELS = [
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='text-embedding-3-small',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: 8191,
ModelPropertyKey.MAX_CHUNKS: 32,
},
pricing=PriceConfig(
input=0.00002,
unit=0.001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='text-embedding-3-large',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: 8191,
ModelPropertyKey.MAX_CHUNKS: 32,
},
pricing=PriceConfig(
input=0.00013,
unit=0.001,
currency='USD',
)
)
)
]
SPEECH2TEXT_BASE_MODELS = [
AzureBaseModel(
base_model_name='whisper-1',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.SPEECH2TEXT,
model_properties={
ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm'
}
)
)
]
TTS_BASE_MODELS = [
AzureBaseModel(
base_model_name='tts-1',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={
ModelPropertyKey.DEFAULT_VOICE: 'alloy',
ModelPropertyKey.VOICES: [
{
'mode': 'alloy',
'name': 'Alloy',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'echo',
'name': 'Echo',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'fable',
'name': 'Fable',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'onyx',
'name': 'Onyx',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'nova',
'name': 'Nova',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'shimmer',
'name': 'Shimmer',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
],
ModelPropertyKey.WORD_LIMIT: 120,
ModelPropertyKey.AUDOI_TYPE: 'mp3',
ModelPropertyKey.MAX_WORKERS: 5
},
pricing=PriceConfig(
input=0.015,
unit=0.001,
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='tts-1-hd',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label'
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={
ModelPropertyKey.DEFAULT_VOICE: 'alloy',
ModelPropertyKey.VOICES: [
{
'mode': 'alloy',
'name': 'Alloy',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'echo',
'name': 'Echo',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'fable',
'name': 'Fable',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'onyx',
'name': 'Onyx',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'nova',
'name': 'Nova',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
{
'mode': 'shimmer',
'name': 'Shimmer',
'language': ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID']
},
],
ModelPropertyKey.WORD_LIMIT: 120,
ModelPropertyKey.AUDOI_TYPE: 'mp3',
ModelPropertyKey.MAX_WORKERS: 5
},
pricing=PriceConfig(
input=0.03,
unit=0.001,
currency='USD',
)
)
)
]

View File

@ -15,6 +15,8 @@ help:
supported_model_types:
- llm
- text-embedding
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
@ -99,6 +101,36 @@ model_credential_schema:
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-small
value: text-embedding-3-small
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-large
value: text-embedding-3-large
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: whisper-1
value: whisper-1
show_on:
- variable: __model_type
value: speech2text
- label:
en_US: tts-1
value: tts-1
show_on:
- variable: __model_type
value: tts
- label:
en_US: tts-1-hd
value: tts-1-hd
show_on:
- variable: __model_type
value: tts
placeholder:
zh_Hans: 在此输入您的模型版本
en_US: Enter your model version

View File

@ -0,0 +1,82 @@
import copy
from typing import IO, Optional
from openai import AzureOpenAI
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
"""
Model class for Azure OpenAI speech-to-text model.
"""
def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
return self._speech2text_invoke(model, credentials, file)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
audio_file_path = self._get_demo_file_path()
with open(audio_file_path, 'rb') as audio_file:
self._speech2text_invoke(model, credentials, audio_file)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:return: text for given audio file
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
# init model client
client = AzureOpenAI(**credentials_kwargs)
response = client.audio.transcriptions.create(model=model, file=file)
return response.text
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
return ai_model_entity.entity
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
return None
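Stripped of the Dify plumbing, the invocation above reduces to a plain Azure OpenAI transcription call. The keyword arguments below are an assumption about what _to_credential_kwargs expands to (that helper lives in _common.py and is not part of this diff); the endpoint, key, API version and deployment name are placeholders.

from openai import AzureOpenAI

client = AzureOpenAI(
    api_key="...",                                          # placeholder key
    api_version="2023-12-01-preview",                       # assumed API version
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder endpoint
)
with open("sample.wav", "rb") as audio_file:
    transcript = client.audio.transcriptions.create(
        model="my-whisper-deployment",  # Azure deployment name, not the base model name
        file=audio_file,
    )
print(transcript.text)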

View File

@ -0,0 +1,174 @@
import concurrent.futures
import copy
from functools import reduce
from io import BytesIO
from typing import Optional
from flask import Response, stream_with_context
from openai import AzureOpenAI
from pydub import AudioSegment
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel
from extensions.ext_storage import storage
class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
"""
Model class for Azure OpenAI text-to-speech model.
"""
def _invoke(self, model: str, tenant_id: str, credentials: dict,
content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any:
"""
_invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:param streaming: output is streaming
:param user: unique user id
:return: text translated to audio file
"""
audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,
credentials=credentials,
content_text=content_text,
tenant_id=tenant_id,
voice=voice)),
status=200, mimetype=f'audio/{audio_type}')
else:
return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice)
def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
"""
validate credentials text2speech model
:param model: model name
:param credentials: model credentials
:param user: unique user id
:return: text translated to audio file
"""
try:
self._tts_invoke(
model=model,
credentials=credentials,
content_text='Hello Dify!',
voice=self._get_model_default_voice(model, credentials),
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response:
"""
_tts_invoke text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:return: text translated to audio file
"""
audio_type = self._get_model_audio_type(model, credentials)
word_limit = self._get_model_word_limit(model, credentials)
max_workers = self._get_model_workers_limit(model, credentials)
try:
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
audio_bytes_list = list()
# Create a thread pool and map the function to the list of sentences
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice,
credentials=credentials) for sentence in sentences]
for future in futures:
try:
if future.result():
audio_bytes_list.append(future.result())
except Exception as ex:
raise InvokeBadRequestError(str(ex))
if len(audio_bytes_list) > 0:
audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in
audio_bytes_list if audio_bytes]
combined_segment = reduce(lambda x, y: x + y, audio_segments)
buffer: BytesIO = BytesIO()
combined_segment.export(buffer, format=audio_type)
buffer.seek(0)
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
except Exception as ex:
raise InvokeBadRequestError(str(ex))
# Todo: To improve the streaming function
def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str,
voice: str) -> any:
"""
_tts_invoke_streaming text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model timbre
:return: text translated to audio file
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
voice = self._get_model_default_voice(model, credentials)
word_limit = self._get_model_word_limit(model, credentials)
audio_type = self._get_model_audio_type(model, credentials)
tts_file_id = self._get_file_name(content_text)
file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}'
try:
client = AzureOpenAI(**credentials_kwargs)
sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit))
for sentence in sentences:
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
# response.stream_to_file(file_path)
storage.save(file_path, response.read())
except Exception as ex:
raise InvokeBadRequestError(str(ex))
def _process_sentence(self, sentence: str, model: str,
voice, credentials: dict):
"""
_tts_invoke openai text2speech model api
:param model: model name
:param credentials: model credentials
:param voice: model timbre
:param sentence: text content to be translated
:return: text translated to audio file
"""
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
client = AzureOpenAI(**credentials_kwargs)
response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip())
if isinstance(response.read(), bytes):
return response.read()
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
return ai_model_entity.entity
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
for ai_model_entity in TTS_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
return None
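The non-streaming path above fans each sentence out to a thread pool and stitches the returned audio together with pydub. A generic sketch of that fan-out and merge pattern, with the per-sentence API call replaced by a stand-in that returns silence; WAV is used so the sketch runs without an ffmpeg install.

import concurrent.futures
from functools import reduce
from io import BytesIO

from pydub import AudioSegment


def synthesize(sentence: str) -> bytes:
    # Stand-in for client.audio.speech.create(...).read(): half a second of silence.
    buf = BytesIO()
    AudioSegment.silent(duration=500).export(buf, format="wav")
    return buf.getvalue()


def merge_sentences(sentences: list[str], max_workers: int = 5) -> bytes:
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(synthesize, s) for s in sentences]
        chunks = [f.result() for f in futures if f.result()]
    segments = [AudioSegment.from_file(BytesIO(c), format="wav") for c in chunks]
    combined = reduce(lambda a, b: a + b, segments)  # concatenate in sentence order
    out = BytesIO()
    combined.export(out, format="wav")
    return out.getvalue()


print(len(merge_sentences(["Hello Dify!", "How are you today?"])))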

View File

@ -108,7 +108,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(e)
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:

View File

@ -472,7 +472,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
else:
raise ValueError(f"Got unknown type {message}")
if message.name is not None:
if message.name:
message_dict["user_name"] = message.name
return message_dict

View File

@ -0,0 +1,11 @@
<svg width="112" height="24" viewBox="0 0 112 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M57.4336 17.092C56.4746 16.5453 55.7216 15.7924 55.1749 14.8244C54.6283 13.8564 54.3594 12.763 54.3594 11.544C54.3594 10.3251 54.6283 9.2137 55.1749 8.24571C55.7216 7.27772 56.4746 6.52485 57.4336 5.98708C58.3926 5.4493 59.4861 5.18042 60.6961 5.18042C61.6999 5.18042 62.623 5.3776 63.4476 5.77197C64.2722 6.16633 64.9445 6.73995 65.4554 7.49284L64.568 8.13816C64.1199 7.51076 63.5642 7.04469 62.9009 6.731C62.2377 6.41729 61.5027 6.26492 60.705 6.26492C59.7281 6.26492 58.8498 6.48899 58.0789 6.92818C57.2992 7.36736 56.6986 7.98579 56.2505 8.79244C55.8113 9.59014 55.5872 10.5133 55.5872 11.553C55.5872 12.5926 55.8113 13.5159 56.2505 14.3136C56.6896 15.1112 57.2992 15.7297 58.0789 16.1778C58.8587 16.617 59.7281 16.8411 60.705 16.8411C61.5027 16.8411 62.2377 16.6888 62.9009 16.375C63.5642 16.0613 64.1199 15.5953 64.568 14.9678L65.4554 15.6132C64.9445 16.366 64.2722 16.9396 63.4476 17.334C62.623 17.7284 61.7089 17.9255 60.6961 17.9255C59.4771 17.9255 58.3926 17.6568 57.4336 17.11V17.092Z" fill="#F55036"/>
<path d="M67.2754 0H68.4763V17.8181H67.2754V0Z" fill="#F55036"/>
<path d="M73.6754 17.092C72.7254 16.5454 71.9725 15.7924 71.4347 14.8244C70.888 13.8564 70.6191 12.763 70.6191 11.544C70.6191 10.3251 70.888 9.23163 71.4347 8.26364C71.9814 7.29566 72.7254 6.54277 73.6754 5.99604C74.6255 5.4493 75.6921 5.18042 76.8841 5.18042C78.0762 5.18042 79.1338 5.4493 80.0928 5.99604C81.0429 6.54277 81.7957 7.29566 82.3335 8.26364C82.8803 9.23163 83.1492 10.3251 83.1492 11.544C83.1492 12.763 82.8803 13.8564 82.3335 14.8244C81.7868 15.7924 81.0429 16.5454 80.0928 17.092C79.1427 17.6387 78.0673 17.9076 76.8841 17.9076C75.7011 17.9076 74.6344 17.6387 73.6754 17.092ZM79.4655 16.1599C80.2273 15.7118 80.8277 15.0843 81.2669 14.2867C81.7062 13.489 81.9302 12.5747 81.9302 11.553C81.9302 10.5312 81.7062 9.61703 81.2669 8.81933C80.8277 8.02164 80.2273 7.39425 79.4655 6.9461C78.7036 6.49796 77.8431 6.27389 76.8841 6.27389C75.9251 6.27389 75.0646 6.49796 74.3028 6.9461C73.5409 7.39425 72.9405 8.02164 72.5013 8.81933C72.0621 9.61703 71.838 10.5312 71.838 11.553C71.838 12.5747 72.0621 13.489 72.5013 14.2867C72.9405 15.0843 73.5409 15.7118 74.3028 16.1599C75.0646 16.608 75.9251 16.8322 76.8841 16.8322C77.8431 16.8322 78.7036 16.608 79.4655 16.1599Z" fill="#F55036"/>
<path d="M96.2799 5.27905V17.8091H95.1237V15.1203C94.7114 15.9986 94.0929 16.6887 93.2774 17.1728C92.4618 17.6567 91.5027 17.9077 90.4003 17.9077C88.769 17.9077 87.4873 17.4506 86.5553 16.5364C85.6231 15.6222 85.166 14.3136 85.166 12.6017V5.27905H86.367V12.5031C86.367 13.9102 86.7255 14.9858 87.4515 15.7207C88.1775 16.4557 89.1903 16.8232 90.4989 16.8232C91.9061 16.8232 93.0264 16.384 93.851 15.5057C94.6756 14.6272 95.0878 13.4442 95.0878 11.9563V5.27905H96.2889H96.2799Z" fill="#F55036"/>
<path d="M110.952 0V17.8181H109.777V14.8604C109.284 15.8374 108.585 16.5902 107.689 17.119C106.793 17.6479 105.78 17.9077 104.642 17.9077C103.503 17.9077 102.419 17.6389 101.469 17.0922C100.528 16.5454 99.7838 15.7925 99.246 14.8336C98.7083 13.8745 98.4395 12.781 98.4395 11.5441C98.4395 10.3073 98.7083 9.2138 99.246 8.24582C99.7838 7.27783 100.519 6.52496 101.469 5.98718C102.41 5.44941 103.468 5.18053 104.642 5.18053C105.816 5.18053 106.766 5.44044 107.653 5.96925C108.541 6.49807 109.24 7.23301 109.75 8.17411V0H110.952ZM107.295 16.16C108.057 15.7119 108.657 15.0844 109.096 14.2868C109.535 13.4891 109.759 12.5749 109.759 11.5531C109.759 10.5313 109.535 9.61713 109.096 8.81944C108.657 8.02174 108.057 7.39434 107.295 6.9462C106.533 6.49807 105.672 6.27399 104.713 6.27399C103.754 6.27399 102.894 6.49807 102.132 6.9462C101.37 7.39434 100.77 8.02174 100.331 8.81944C99.8914 9.61713 99.6673 10.5313 99.6673 11.5531C99.6673 12.5749 99.8914 13.4891 100.331 14.2868C100.77 15.0844 101.37 15.7119 102.132 16.16C102.894 16.6081 103.754 16.8322 104.713 16.8322C105.672 16.8322 106.533 16.6081 107.295 16.16Z" fill="#F55036"/>
<path d="M30.6085 5.27024C27.077 5.27024 24.209 8.13835 24.209 11.6697C24.209 15.201 27.077 18.0692 30.6085 18.0692C34.1399 18.0692 37.0079 15.201 37.0079 11.6697C37.0079 8.13835 34.1399 5.27921 30.6085 5.27024ZM30.6085 15.6672C28.4036 15.6672 26.611 13.8746 26.611 11.6697C26.611 9.46486 28.4036 7.67228 30.6085 7.67228C32.8133 7.67228 34.6059 9.46486 34.6059 11.6697C34.6059 13.8746 32.8133 15.6672 30.6085 15.6672Z" fill="black"/>
<path d="M6.45358 5.23422C2.92222 5.19837 0.036187 8.0396 0.000335591 11.571C-0.0355158 15.1023 2.80571 17.9974 6.33706 18.0242C6.37292 18.0242 6.41773 18.0242 6.45358 18.0242H8.55986V15.6311H6.45358C4.24873 15.658 2.43823 13.8923 2.41134 11.6785C2.38445 9.47365 4.15014 7.66315 6.36395 7.63626C6.39084 7.63626 6.4267 7.63626 6.45358 7.63626C8.65844 7.63626 10.46 9.42884 10.46 11.6337V17.5222C10.46 19.7092 8.67637 21.4929 6.48943 21.5197C5.44078 21.5197 4.44591 21.0895 3.71095 20.3455L2.01698 22.0395C3.1911 23.2227 4.7865 23.8949 6.45358 23.9128H6.54321C10.0298 23.859 12.8351 21.0357 12.853 17.5491V11.4724C12.7635 8.00374 9.93116 5.23422 6.46254 5.23422H6.45358Z" fill="black"/>
<path d="M51.2406 11.5082C51.151 8.03961 48.3187 5.27009 44.8501 5.27009C41.3187 5.23423 38.4237 8.07545 38.3968 11.6068C38.361 15.1382 41.2022 18.0331 44.7335 18.0601C44.7694 18.0601 44.8143 18.0601 44.8501 18.0601H46.9563V15.667H44.8501C42.6452 15.6939 40.8347 13.9282 40.8078 11.7144C40.7809 9.5095 42.5467 7.69902 44.7604 7.67213C44.7874 7.67213 44.8232 7.67213 44.8501 7.67213C47.055 7.67213 48.8565 9.46469 48.8565 11.6696V23.626L51.2406 23.6528V11.5082Z" fill="black"/>
<path d="M14.6808 18.0602H17.0649V11.6607C17.0649 9.45589 18.8575 7.66332 21.0623 7.66332C21.7883 7.66332 22.4695 7.8605 23.0611 8.2011L24.2621 6.12172C23.3209 5.57498 22.2276 5.27024 21.0713 5.27024C17.5399 5.27024 14.6719 8.13835 14.6719 11.6697V18.0692L14.6808 18.0602Z" fill="black"/>
</svg>


View File

@ -0,0 +1,4 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="24" height="24" rx="12" fill="#F55036"/>
<path d="M12.146 6.00022C9.87734 5.97718 8.02325 7.80249 8.00022 10.0712C7.97718 12.3398 9.80249 14.1997 12.0712 14.217C12.0942 14.217 12.123 14.217 12.146 14.217H13.4992V12.6796H12.146C10.7295 12.6968 9.56641 11.5625 9.54913 10.1403C9.53186 8.72377 10.6662 7.56065 12.0884 7.54337C12.1057 7.54337 12.1287 7.54337 12.146 7.54337C13.5625 7.54337 14.7199 8.69498 14.7199 10.1115V13.8945C14.7199 15.2995 13.574 16.4453 12.169 16.4626C11.4953 16.4626 10.8562 16.1862 10.384 15.7083L9.29578 16.7965C10.0501 17.5566 11.075 17.9885 12.146 18H12.2036C14.4435 17.9654 16.2457 16.1516 16.2572 13.9117V10.0078C16.1997 7.77945 14.3801 6.00022 12.1518 6.00022H12.146Z" fill="white"/>
</svg>


View File

@ -0,0 +1,29 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class GroqProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(
model='llama2-70b-4096',
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex

View File

@ -0,0 +1,32 @@
provider: groq
label:
zh_Hans: GroqCloud
en_US: GroqCloud
description:
en_US: GroqCloud provides access to the Groq Cloud API, which hosts models like LLama2 and Mixtral.
zh_Hans: GroqCloud 提供对 Groq Cloud API 的访问,其中托管了 LLama2 和 Mixtral 等模型。
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.svg
background: "#F5F5F4"
help:
title:
en_US: Get your API Key from GroqCloud
zh_Hans: 从 GroqCloud 获取 API Key
url:
en_US: https://console.groq.com/
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key

View File

@ -0,0 +1,25 @@
model: llama2-70b-4096
label:
zh_Hans: Llama-2-70B-4096
en_US: Llama-2-70B-4096
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 4096
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 4096
pricing:
input: '0.7'
output: '0.8'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,26 @@
from collections.abc import Generator
from typing import Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials['mode'] = 'chat'
credentials['endpoint_url'] = 'https://api.groq.com/openai/v1'
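Because Groq exposes an OpenAI-compatible API, the provider only needs to force chat mode and repoint the client at https://api.groq.com/openai/v1. A minimal sketch of the equivalent call with the plain openai SDK; the API key is a placeholder.

from openai import OpenAI

client = OpenAI(
    api_key="gsk_...",                          # GroqCloud API key (placeholder)
    base_url="https://api.groq.com/openai/v1",
)
completion = client.chat.completions.create(
    model="llama2-70b-4096",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=64,
)
print(completion.choices[0].message.content)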

View File

@ -0,0 +1,25 @@
model: mixtral-8x7b-32768
label:
zh_Hans: Mixtral-8x7b-Instruct-v0.1
en_US: Mixtral-8x7b-Instruct-v0.1
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 20480
pricing:
input: '0.27'
output: '0.27'
unit: '0.000001'
currency: USD

View File

@ -2,7 +2,7 @@ provider: jina
label:
en_US: Jina
description:
en_US: Embedding Model Supported
en_US: Embedding and Rerank Model Supported
icon_small:
en_US: icon_s_en.svg
icon_large:
@ -13,9 +13,10 @@ help:
en_US: Get your API key from Jina AI
zh_Hans: 从 Jina 获取 API Key
url:
en_US: https://jina.ai/embeddings/
en_US: https://jina.ai/
supported_model_types:
- text-embedding
- rerank
configurate_methods:
- predefined-model
provider_credential_schema:

View File

@ -0,0 +1,4 @@
model: jina-reranker-v1-base-en
model_type: rerank
model_properties:
context_size: 8192

View File

@ -0,0 +1,105 @@
from typing import Optional
import httpx
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class JinaRerankModel(RerankModel):
"""
Model class for Jina rerank model.
"""
def _invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
user: Optional[str] = None) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
try:
response = httpx.post(
"https://api.jina.ai/v1/rerank",
json={
"model": model,
"query": query,
"documents": docs,
"top_n": top_n
},
headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results['results']:
rerank_document = RerankDocument(
index=result['index'],
text=result['document']['text'],
score=result['relevance_score'],
)
if score_threshold is None or result['relevance_score'] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError]
}
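For context, the parsing loop above expects a payload from https://api.jina.ai/v1/rerank shaped roughly like the following; the scores and snippets are made up.

example_rerank_response = {
    "model": "jina-reranker-v1-base-en",
    "results": [
        {
            "index": 0,
            "document": {"text": "Carson City is the capital city of ..."},
            "relevance_score": 0.92,
        },
        {
            "index": 1,
            "document": {"text": "The Commonwealth of the Northern Mariana Islands ..."},
            "relevance_score": 0.31,
        },
    ],
}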

View File

@ -1,20 +1,32 @@
from os.path import abspath, dirname, join
from threading import Lock
from transformers import AutoTokenizer
class JinaTokenizer:
@staticmethod
def _get_num_tokens_by_jina_base(text: str) -> int:
_tokenizer = None
_lock = Lock()
@classmethod
def _get_tokenizer(cls):
if cls._tokenizer is None:
with cls._lock:
if cls._tokenizer is None:
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
return cls._tokenizer
@classmethod
def _get_num_tokens_by_jina_base(cls, text: str) -> int:
"""
use jina tokenizer to get num tokens
"""
base_path = abspath(__file__)
gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
tokenizer = cls._get_tokenizer()
tokens = tokenizer.encode(text)
return len(tokens)
@staticmethod
def get_num_tokens(text: str) -> int:
return JinaTokenizer._get_num_tokens_by_jina_base(text)
@classmethod
def get_num_tokens(cls, text: str) -> int:
return cls._get_num_tokens_by_jina_base(text)
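The refactor above caches the HuggingFace tokenizer behind a class-level lock instead of reloading it on every call. A generic sketch of that lazily initialised, lock-guarded singleton; the second is-None check avoids rebuilding the object when two threads race past the first one.

from threading import Lock


class LazyResource:
    _instance = None
    _lock = Lock()

    @classmethod
    def get(cls):
        if cls._instance is None:          # fast path once initialised, no lock taken
            with cls._lock:
                if cls._instance is None:  # re-check under the lock
                    cls._instance = object()  # stand-in for AutoTokenizer.from_pretrained(...)
        return cls._instance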

View File

@ -57,7 +57,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(e)
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:

View File

@ -1,6 +1,6 @@
from collections.abc import Generator
from os.path import join
from typing import cast
from urllib.parse import urljoin
from httpx import Timeout
from openai import (
@ -313,10 +313,13 @@ class LocalAILarguageModel(LargeLanguageModel):
:param credentials: credentials dict
:return: client kwargs
"""
if not credentials['server_url'].endswith('/'):
credentials['server_url'] += '/'
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": join(credentials['server_url'], 'v1'),
"base_url": urljoin(credentials['server_url'], 'v1'),
}
return client_kwargs
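The switch from os.path.join to urljoin (together with the new trailing-slash guard) matters because os.path.join is a filesystem helper, while urljoin resolves the relative segment against the base URL and silently drops the last path segment when the base lacks a trailing slash:

from urllib.parse import urljoin

print(urljoin("http://localhost:8080/", "v1"))      # http://localhost:8080/v1
print(urljoin("http://localhost:8080/llm/", "v1"))  # http://localhost:8080/llm/v1
print(urljoin("http://localhost:8080/llm", "v1"))   # http://localhost:8080/v1  (segment replaced)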

View File

@ -59,7 +59,7 @@ class LocalAITextEmbeddingModel(TextEmbeddingModel):
try:
response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10)
except Exception as e:
raise InvokeConnectionError(e)
raise InvokeConnectionError(str(e))
if response.status_code != 200:
try:

View File

@ -65,7 +65,7 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
try:
response = post(url, headers=headers, data=dumps(data))
except Exception as e:
raise InvokeConnectionError(e)
raise InvokeConnectionError(str(e))
if response.status_code != 200:
raise InvokeServerUnavailableError(response.text)

Binary file not shown (2.3 KiB before, 7.2 KiB after).

View File

@ -24,7 +24,7 @@ parameter_rules:
min: 1
max: 8000
- name: safe_prompt
defulat: false
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.

View File

@ -24,7 +24,7 @@ parameter_rules:
min: 1
max: 8000
- name: safe_prompt
defulat: false
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.

View File

@ -24,7 +24,7 @@ parameter_rules:
min: 1
max: 8000
- name: safe_prompt
defulat: false
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.

View File

@ -24,7 +24,7 @@ parameter_rules:
min: 1
max: 2048
- name: safe_prompt
defulat: false
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.

View File

@ -24,7 +24,7 @@ parameter_rules:
min: 1
max: 8000
- name: safe_prompt
defulat: false
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.

View File

@ -2,4 +2,4 @@ model: whisper-1
model_type: speech2text
model_properties:
file_upload_limit: 25
supported_file_extensions: mp3,mp4,mpeg,mpga,m4a,wav,webm
supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm

View File

@ -34,7 +34,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
:return: text translated to audio file
"""
audio_type = self._get_model_audio_type(model, credentials)
if not voice:
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,

View File

@ -25,6 +25,7 @@ from core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
@ -166,11 +167,23 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
"""
generate custom model entities from credentials
"""
support_function_call = False
features = []
function_calling_type = credentials.get('function_calling_type', 'no_call')
if function_calling_type == 'function_call':
features = [ModelFeature.TOOL_CALL]
support_function_call = True
endpoint_url = credentials["endpoint_url"]
# if not endpoint_url.endswith('/'):
# endpoint_url += '/'
# if 'https://api.openai.com/v1/' == endpoint_url:
# features = [ModelFeature.STREAM_TOOL_CALL]
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
features=features if support_function_call else [],
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
ModelPropertyKey.MODE: credentials.get('mode'),
@ -194,14 +207,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
max=1,
precision=2
),
ParameterRule(
name="top_k",
label=I18nObject(en_US="Top K"),
type=ParameterType.INT,
default=int(credentials.get('top_k', 1)),
min=1,
max=100
),
ParameterRule(
name=DefaultParameterName.FREQUENCY_PENALTY.value,
label=I18nObject(en_US="Frequency Penalty"),
@ -232,7 +237,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
output=Decimal(credentials.get('output_price', 0)),
unit=Decimal(credentials.get('unit', 0)),
currency=credentials.get('currency', "USD")
)
),
)
if credentials['mode'] == 'chat':
@ -292,14 +297,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
raise ValueError("Unsupported completion type for model configuration.")
# annotate tools with names, descriptions, etc.
function_calling_type = credentials.get('function_calling_type', 'no_call')
formatted_tools = []
if tools:
data["tool_choice"] = "auto"
if function_calling_type == 'function_call':
data['functions'] = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
elif function_calling_type == 'tool_call':
data["tool_choice"] = "auto"
for tool in tools:
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
for tool in tools:
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
data["tools"] = formatted_tools
data["tools"] = formatted_tools
if stop:
data["stop"] = stop
@ -367,9 +380,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
if chunk:
#ignore sse comments
# ignore sse comments
if chunk.startswith(':'):
continue
continue
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
chunk_json = None
try:
@ -452,10 +465,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
response_content = ''
tool_calls = None
function_calling_type = credentials.get('function_calling_type', 'no_call')
if completion_type is LLMMode.CHAT:
response_content = output.get('message', {})['content']
    if function_calling_type == 'tool_call':
        tool_calls = output.get('message', {}).get('tool_calls')
    elif function_calling_type == 'function_call':
        tool_calls = output.get('message', {}).get('function_call')
elif completion_type is LLMMode.COMPLETION:
response_content = output['text']
@ -463,7 +479,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
if tool_calls:
    if function_calling_type == 'tool_call':
        assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
    elif function_calling_type == 'function_call':
        assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
usage = response_json.get("usage")
if usage:
@ -522,33 +541,34 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
in
message.tool_calls]
# function_call = message.tool_calls[0]
# message_dict["function_call"] = {
# "name": function_call.function.name,
# "arguments": function_call.function.arguments,
# }
# message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
# in
# message.tool_calls]
function_call = message.tool_calls[0]
message_dict["function_call"] = {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
# message_dict = {
#     "role": "tool",
#     "content": message.content,
#     "tool_call_id": message.tool_call_id
# }
message_dict = {
    "role": "function",
    "content": message.content,
    "name": message.tool_call_id
}
else:
raise ValueError(f"Got unknown type {message}")
if message.name is not None:
if message.name:
message_dict["name"] = message.name
return message_dict
@ -693,3 +713,26 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
tool_calls.append(tool_call)
return tool_calls
def _extract_response_function_call(self, response_function_call) \
-> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
:param response_function_call: response function call
:return: tool call
"""
tool_call = None
if response_function_call:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call['name'],
arguments=response_function_call['arguments']
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call['name'],
type="function",
function=function
)
return tool_call
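For context, the two assistant-message shapes this code distinguishes look roughly like the following (values are made up): the legacy function_call field handled by _extract_response_function_call, and the newer tool_calls list handled by _extract_response_tool_calls.

legacy_function_call_message = {
    "role": "assistant",
    "content": None,
    "function_call": {
        "name": "get_weather",
        "arguments": "{\"location\": \"Boston\"}",
    },
}
tool_calls_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": "{\"location\": \"Boston\"}"},
        },
    ],
}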

View File

@ -75,6 +75,28 @@ model_credential_schema:
value: llm
default: '4096'
type: text-input
- variable: function_calling_type
show_on:
- variable: __model_type
value: llm
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: function_call
label:
en_US: Support
zh_Hans: 支持
# - value: tool_call
# label:
# en_US: Tool Call
# zh_Hans: Tool Call
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- variable: stream_mode_delimiter
label:
zh_Hans: 流模式返回结果的分隔符

View File

@ -53,7 +53,7 @@ class OpenLLMTextEmbeddingModel(TextEmbeddingModel):
# cloud not connect to the server
raise InvokeAuthorizationError(f"Invalid server URL: {e}")
except Exception as e:
raise InvokeConnectionError(e)
raise InvokeConnectionError(str(e))
if response.status_code != 200:
if response.status_code == 400:

View File

@ -34,7 +34,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel):
:return: text translated to audio file
"""
audio_type = self._get_model_audio_type(model, credentials)
if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials):
if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]:
voice = self._get_model_default_voice(model, credentials)
if streaming:
return Response(stream_with_context(self._tts_invoke_streaming(model=model,

View File

@ -308,6 +308,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
type=ParameterType.INT,
use_template='max_tokens',
min=1,
max=credentials.get('context_length', 2048),
default=512,
label=I18nObject(
zh_Hans='最大生成长度',

Some files were not shown because too many files have changed in this diff.