Compare commits

..

138 Commits

Author SHA1 Message Date
4b0d2bf57f chore: update build-push.yml to remove unnecessary tags 2024-09-04 20:20:39 +08:00
94432a0a69 chore: update package versions to 0.8.0-beta1 (#7979) 2024-09-04 19:56:33 +08:00
7e30487f8b feat: update dsl version 2024-09-04 19:41:43 +08:00
Yi
46634638e7 fix: refine the "isInIteration" for workflow 2024-09-04 17:45:07 +08:00
44038b9628 fix: iteration copy 2024-09-04 17:26:43 +08:00
c625f4282f Merge branch 'main' into feat/workflow-parallel-support 2024-09-04 15:22:57 +08:00
4f5dc82459 fix 2024-09-04 15:03:35 +08:00
Yi
5cb018e15d update the method to check if a node is in iteration 2024-09-04 14:59:30 +08:00
4962b2c460 check node edge 2024-09-04 13:27:17 +08:00
Yi
cd42dbdae8 update the log for iteration nodes 2024-09-04 10:27:40 +08:00
78fa1f6868 fix(workflow): detached session issues 2024-09-03 18:23:37 +08:00
Yi
6bee121ebe update log in web app 2024-09-03 17:27:32 +08:00
36d95e49b0 fix(iteration): iterator_length not correct 2024-09-03 12:01:56 +08:00
Yi
3431b19f9a update styling and iteration log 2024-09-03 11:46:04 +08:00
Yi
b28c7b1cda Merge branch 'feat/workflow-parallel-support' of github.com:langgenius/dify into feat/workflow-parallel-support 2024-09-03 10:35:15 +08:00
Yi
83343eefe6 update parallel log 2024-09-03 10:34:50 +08:00
d92966545b fix: migration 2024-09-02 22:41:08 +08:00
f71c51cb9a Merge branch 'refs/heads/main' into feat/workflow-parallel-support 2024-09-02 22:37:23 +08:00
955884b87e chore(workflow): max thread submit count 2024-09-02 20:20:32 +08:00
5ca9df65de feat(workflow): add thread pool 2024-09-02 19:02:45 +08:00
166365a502 feat(workflow): add thread pool 2024-09-02 19:02:21 +08:00
70aced0100 fix 2024-09-02 18:38:21 +08:00
35d9c59a29 Merge remote-tracking branch 'origin/feat/workflow-parallel-support' into feat/workflow-parallel-support 2024-09-02 17:56:16 +08:00
bbc922dffa merge main 2024-09-02 17:55:28 +08:00
7035f64ce3 fix: next step 2024-09-02 17:52:54 +08:00
81d09d471c Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/workflow/app_generator.py
2024-09-02 17:52:51 +08:00
5bda3a384a fix(workflow): bugs 2024-09-02 17:49:51 +08:00
43240fcd41 fix 2024-09-02 14:50:05 +08:00
52b4623131 fix(workflow): fix merge branch node id err 2024-09-02 13:56:07 +08:00
0dabf799c0 fix(workflow): fix merge branch node id err 2024-09-02 11:52:14 +08:00
Yi
29b1ce781d fix: node end status 2024-09-01 22:00:54 +08:00
Yi
71a7d890cc fix styling 2024-08-30 23:31:05 +08:00
Yi
ee1587c939 fix: make the End node always nested in the root 2024-08-30 20:14:56 +08:00
Yi
d7c0ca852e feat: inner parallels will be added to its corresponding branch 2024-08-30 20:08:57 +08:00
162e9677c7 fix(workflow): missing parallel event in workflow app 2024-08-30 20:04:17 +08:00
77e62f7fee fix(workflow): run node in multi parallel bugs 2024-08-30 18:55:33 +08:00
Yi
e3295181d2 fix a typo 2024-08-30 18:01:13 +08:00
Yi
2b5b856126 solve the branch issue 2024-08-30 17:58:29 +08:00
Yi
e3ae529a55 update the onNodeFinished method for nodes being passed through more than once 2024-08-30 17:00:02 +08:00
Yi
708256ef1d Merge branch 'feat/workflow-parallel-support' of github.com:langgenius/dify into feat/workflow-parallel-support 2024-08-30 15:23:21 +08:00
7c9081a8fc fix 2024-08-30 13:44:01 +08:00
Yi
1bde57e591 delete console logs 2024-08-29 17:54:26 +08:00
Yi
32a11cbb6a update the parallel workflow log for iteration and chatflow preview 2024-08-29 17:26:17 +08:00
Yi
3e257ae907 update the workflow parallel log 2024-08-29 16:38:51 +08:00
f43596f226 fix: parallel branch limit 2024-08-29 11:31:34 +08:00
ae22015fe7 fix(workflow): loop check 2024-08-28 21:47:47 +08:00
790dd3b22f fix(workflow): duplicate nodes in parallel 2024-08-28 19:01:45 +08:00
5d34e080eb fix: migration 2024-08-28 18:02:49 +08:00
6b6750b9ad Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/services/app_generate_service.py
2024-08-28 18:01:57 +08:00
74c8004944 fix(graph_engine): fix execute loops in parallel 2024-08-28 17:42:42 +08:00
4418fa1d2b fix: bug 2024-08-28 17:40:50 +08:00
c2bb11405f fix(workflow): parallel not yield 2024-08-28 16:13:57 +08:00
8ba5673606 feat: iteration support parallel 2024-08-28 16:00:17 +08:00
b0a81c654b fix(workflow): parallel execution after if-else that only one branch runs 2024-08-28 15:53:39 +08:00
cd52633b0e fix(graph_engine): parent_parallel_id missing 2024-08-27 16:45:14 +08:00
4256e9d47f chore(iteration): keep start_node_id using in parallel start nodes 2024-08-27 16:38:33 +08:00
4e3dc36e37 fix: workflow run edge status 2024-08-27 14:39:56 +08:00
b9f34f679f fix: iteration start node id 2024-08-26 22:00:17 +08:00
9c8144e463 feat: parallel hover 2024-08-26 17:49:11 +08:00
76bb8d1c1a Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/services/app_generate_service.py
#	api/services/workflow_service.py
2024-08-26 16:17:19 +08:00
1016db160e feat: parallel hover 2024-08-26 16:09:22 +08:00
6c61776ee1 fix test 2024-08-25 22:02:21 +08:00
4771e85630 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/tests/integration_tests/workflow/nodes/test_code.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
2024-08-24 17:26:44 +08:00
85d319719c fix end node bug 2024-08-24 17:17:18 +08:00
42899fb3be fix bug 2024-08-23 00:38:42 +08:00
5b22d8f8b2 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/question_classifier/question_classifier_node.py
2024-08-23 00:32:28 +08:00
fe2b300288 fix lint 2024-08-22 23:54:07 +08:00
ec4fc784f0 fix iteration start node 2024-08-22 23:53:44 +08:00
d6da7b0336 fix dialogue_count 2024-08-22 13:06:17 +08:00
92072e2ed7 fix: ruff issues 2024-08-21 17:26:51 +08:00
e34497ded1 fix: merge issues 2024-08-21 17:25:26 +08:00
35be41b337 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/nodes/start/start_node.py
#	api/core/workflow/nodes/variable_assigner/__init__.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
2024-08-21 16:59:23 +08:00
412be6d014 fix bug 2024-08-21 16:43:00 +08:00
1d88b62e25 fix(workflow): fix node link to previous node issue 2024-08-20 23:28:11 +08:00
617ea4b3b8 fix(workflow): fix parallel bug 2024-08-20 22:16:41 +08:00
755a9658c7 fix(workflow): add parallel id into published events 2024-08-18 20:18:13 +08:00
5d7865737f fix(workflow): issues in workflow parallels 2024-08-16 22:47:58 +08:00
352c45c8a2 feat(workflow): integrate parallel into workflow apps 2024-08-16 21:33:09 +08:00
1973f5003b feat: frontend support parallel 2024-08-16 16:55:08 +08:00
5b5e6e31bf fix: answer node unit tests 2024-08-16 01:44:00 +08:00
90221c0a90 fix: unit tests 2024-08-16 01:43:35 +08:00
91e51ce2b8 fix(workflow): issues by merging main branch 2024-08-16 01:36:19 +08:00
db9b0ee985 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/base_app_runner.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/apps/workflow/generate_task_pipeline.py
#	api/core/app/task_pipeline/workflow_cycle_state_manager.py
#	api/core/workflow/entities/node_entities.py
#	api/core/workflow/nodes/llm/llm_node.py
#	api/core/workflow/workflow_engine_manager.py
#	api/tests/integration_tests/workflow/nodes/test_llm.py
#	api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
#	api/tests/unit_tests/core/workflow/nodes/test_answer.py
#	api/tests/unit_tests/core/workflow/nodes/test_if_else.py
#	api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py
2024-08-16 01:19:29 +08:00
c5192650fb fix: unit tests in workflow 2024-08-15 23:47:59 +08:00
702df31db7 fix(workflow): fix generate issues in workflow 2024-08-15 20:45:23 +08:00
1da5862a96 feat(workflow): fix iteration single debug 2024-08-15 03:12:49 +08:00
6f6b32e1ee feat(workflow): integrate workflow entry with workflow app 2024-08-14 19:22:15 +08:00
674af04c39 fix migration version depends 2024-08-13 17:15:21 +08:00
2980e31ddf fix issues when merging from main 2024-08-13 17:11:19 +08:00
14d020fffe Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/app/apps/advanced_chat/generate_task_pipeline.py
#	api/core/app/apps/workflow/app_runner.py
#	api/core/app/task_pipeline/workflow_cycle_manage.py
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/base_node.py
#	api/core/workflow/workflow_engine_manager.py
2024-08-13 17:05:39 +08:00
8401a11109 feat(workflow): integrate workflow entry with advanced chat app 2024-08-13 16:21:10 +08:00
8d27ec364f fix bug 2024-07-31 02:27:23 +08:00
c9bb366e1a Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/entities/variable_pool.py
#	api/core/workflow/nodes/iteration/iteration_node.py
#	api/core/workflow/workflow_engine_manager.py
2024-07-31 02:25:31 +08:00
917aacbf7f add chatflow app event convert 2024-07-31 02:21:35 +08:00
0818b7b078 remove iteration special logic 2024-07-26 21:27:01 +08:00
88dcd7b737 fix bug 2024-07-26 20:29:12 +08:00
63addf8c94 add parallel branch events 2024-07-26 20:27:17 +08:00
483f71f03c fix logging 2024-07-26 20:13:11 +08:00
beea1e1663 fix lint 2024-07-26 19:47:12 +08:00
38f8c45755 add events in interation node 2024-07-26 19:47:02 +08:00
a31feacf28 fix iteration 2024-07-26 02:43:40 +08:00
ae351bd40e add iteration support 2024-07-25 23:07:27 +08:00
df133168dd fix lint 2024-07-25 21:06:23 +08:00
7c67ba8991 remove threadpool 2024-07-25 21:05:53 +08:00
4097f7c069 add parallel branch output 2024-07-25 19:39:06 +08:00
f4eb7cd037 add end stream output test 2024-07-25 04:03:53 +08:00
833584ba76 Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/workflow_entry.py
2024-07-24 23:43:14 +08:00
ec7760795f save 2024-07-24 00:24:24 +08:00
e9bfedab9b Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/entities/variable_pool.py
2024-07-23 17:28:57 +08:00
7303b53af1 fix bug 2024-07-23 16:18:52 +08:00
0fe516568a Merge branch 'refs/heads/main' into feat/workflow-parallel-support
# Conflicts:
#	api/core/workflow/nodes/code/code_node.py
#	api/core/workflow/nodes/end/end_node.py
#	api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
#	api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
#	api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
2024-07-23 16:18:34 +08:00
2c695ded79 fix bugs 2024-07-23 00:10:23 +08:00
a603e01f5e fix bug 2024-07-22 19:57:32 +08:00
beaac5033a fix bug 2024-07-20 00:57:41 +08:00
dad1a967ee finished answer stream output 2024-07-20 00:49:46 +08:00
7ad77e9e77 fix test 2024-07-18 08:19:58 +08:00
f67a88f44d fix test 2024-07-17 21:17:04 +08:00
90e518b05b fix bugs 2024-07-17 16:54:49 +08:00
cc96acdae3 fix bugs 2024-07-17 11:26:33 +08:00
16e2d00157 optimize 2024-07-17 01:07:23 +08:00
4ef3d4e65c optimize 2024-07-17 01:02:40 +08:00
775e52db4d merge 2024-07-16 17:46:20 +08:00
00ec36d47c add graph engine test 2024-07-16 16:37:37 +08:00
00fb23d0c9 graph engine implement 2024-07-15 23:40:02 +08:00
821e09b259 add run logics 2024-07-12 19:33:47 +08:00
d77b689a99 completed parallel tests 2024-07-10 21:21:06 +08:00
0e885a3cae refactor runtime 2024-07-08 16:29:13 +08:00
1adaf42f9d refactor graph 2024-07-07 23:08:45 +08:00
fed068ac2e Merge branch 'refs/heads/main' into feat/workflow-parallel-support 2024-07-07 16:57:21 +08:00
03f56a05eb refactor graph 2024-07-06 03:18:02 +08:00
1b6cd975f3 completed graph init test 2024-07-04 15:40:20 +08:00
0f19b2a986 optimize graph 2024-07-02 21:53:41 +08:00
8375517ccd save 2024-06-29 15:44:52 +08:00
1d8ecac093 save 2024-06-27 05:30:38 +08:00
aaa98c76d5 optimize 2024-06-26 23:56:30 +08:00
216910a4a1 add runtime state of graph 2024-06-25 17:43:13 +08:00
fe27c97fd9 add runtime graph 2024-06-25 14:41:14 +08:00
8217c46116 add new graph structure 2024-06-24 23:34:42 +08:00
1464 changed files with 28739 additions and 39844 deletions

View File

@ -125,7 +125,6 @@ jobs:
with:
images: ${{ env[matrix.image_name_env] }}
tags: |
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}

View File

@ -20,7 +20,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v44
with:
files: api/**
@ -66,7 +66,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v44
with:
files: web/**
@ -97,7 +97,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v45
uses: tj-actions/changed-files@v44
with:
files: |
**.sh
@ -107,7 +107,7 @@ jobs:
dev/**
- name: Super-linter
uses: super-linter/super-linter/slim@v7
uses: super-linter/super-linter/slim@v6
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning

9
.gitignore vendored
View File

@ -153,9 +153,6 @@ docker-legacy/volumes/etcd/*
docker-legacy/volumes/minio/*
docker-legacy/volumes/milvus/*
docker-legacy/volumes/chroma/*
docker-legacy/volumes/opensearch/data/*
docker-legacy/volumes/pgvectors/data/*
docker-legacy/volumes/pgvector/data/*
docker/volumes/app/storage/*
docker/volumes/certbot/*
@ -167,12 +164,6 @@ docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*
docker/volumes/chroma/*
docker/volumes/opensearch/data/*
docker/volumes/myscale/data/*
docker/volumes/myscale/log/*
docker/volumes/unstructured/*
docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/*
docker/nginx/conf.d/default.conf
docker/middleware.env

View File

@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.

View File

@ -39,7 +39,7 @@ DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos
# storage type: local, s3, azure-blob, google-storage
STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage
S3_USE_AWS_MANAGED_IAM=false
@ -73,12 +73,6 @@ TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
# OCI Storage configuration
OCI_ENDPOINT=your-endpoint
OCI_BUCKET_NAME=your-bucket-name
@ -86,13 +80,6 @@ OCI_ACCESS_KEY=your-access-key
OCI_SECRET_KEY=your-secret-key
OCI_REGION=your-region
# Volcengine tos Storage configuration
VOLCENGINE_TOS_ENDPOINT=your-endpoint
VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
VOLCENGINE_TOS_ACCESS_KEY=your-access-key
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
VOLCENGINE_TOS_REGION=your-region
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@ -114,10 +101,11 @@ QDRANT_GRPC_ENABLED=false
QDRANT_GRPC_PORT=6334
# Milvus configuration
MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=
MILVUS_HOST=127.0.0.1
MILVUS_PORT=19530
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# MyScale configuration
MYSCALE_HOST=127.0.0.1

View File

@ -55,7 +55,7 @@ RUN apt-get update \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*

View File

@ -164,7 +164,7 @@ def initialize_extensions(app):
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint not in {"console", "inner_api"}:
if request.blueprint not in ["console", "inner_api"]:
return None
# Check if the user_id contains a dot, indicating the old format
auth_header = request.headers.get("Authorization", "")

View File

@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm):
)
@click.confirmation_option(
prompt=click.style(
"Are you sure you want to reset encrypt key pair? this operation cannot be rolled back!", fg="red"
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
)
)
def reset_encrypt_key_pair():
@ -131,7 +131,7 @@ def reset_encrypt_key_pair():
click.echo(
click.style(
"Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id),
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
fg="green",
)
)
@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
@click.command("vdb-migrate", help="migrate vector db.")
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
def vdb_migrate(scope: str):
if scope in {"knowledge", "all"}:
if scope in ["knowledge", "all"]:
migrate_knowledge_vector_database()
if scope in {"annotation", "all"}:
if scope in ["annotation", "all"]:
migrate_annotation_vector_database()
@ -275,7 +275,8 @@ def migrate_knowledge_vector_database():
for dataset in datasets:
total_count = total_count + 1
click.echo(
f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
f"Processing the {total_count} dataset {dataset.id}. "
+ f"{create_count} created, {skipped_count} skipped."
)
try:
click.echo("Create dataset vdb index: {}".format(dataset.id))
@ -410,8 +411,7 @@ def migrate_knowledge_vector_database():
try:
click.echo(
click.style(
f"Start to created vector index with {len(documents)} documents of {segments_count}"
f" segments for dataset {dataset.id}.",
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
fg="green",
)
)
@ -593,7 +593,7 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
click.echo(
click.style(
"Congratulations! Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
fg="green",
)
)

View File

@ -46,7 +46,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
"""
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="endpoint URL of code execution service",
description="endpoint URL of code execution servcie",
default="http://sandbox:8194",
)
@ -129,12 +129,12 @@ class EndpointConfig(BaseSettings):
)
SERVICE_API_URL: str = Field(
description="Service API Url prefix. used to display Service API Base Url to the front-end.",
description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
default="",
)
APP_WEB_URL: str = Field(
description="WebApp Url prefix. used to display WebAPP API Base Url to the front-end.",
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
default="",
)
@ -272,7 +272,7 @@ class LoggingConfig(BaseSettings):
"""
LOG_LEVEL: str = Field(
description="Log output level, default to INFO. It is recommended to set it to ERROR for production.",
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
default="INFO",
)
@ -415,7 +415,7 @@ class MailConfig(BaseSettings):
"""
MAIL_TYPE: Optional[str] = Field(
description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
default=None,
)

View File

@ -1,7 +1,7 @@
from typing import Any, Optional
from urllib.parse import quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig
@ -9,10 +9,8 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
@ -159,21 +157,6 @@ class CeleryConfig(DatabaseConfig):
default=None,
)
CELERY_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel mode",
default=False,
)
CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
description="Redis Sentinel master name",
default=None,
)
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Redis Sentinel socket timeout",
default=0.1,
)
@computed_field
@property
def CELERY_RESULT_BACKEND(self) -> str | None:
@ -201,8 +184,6 @@ class MiddlewareConfig(
AzureBlobStorageConfig,
GoogleCloudStorageConfig,
TencentCloudCOSStorageConfig,
HuaweiCloudOBSStorageConfig,
VolcengineTOSStorageConfig,
S3StorageConfig,
OCIStorageConfig,
# configs of vdb and vdb providers

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings
@ -38,33 +38,3 @@ class RedisConfig(BaseSettings):
description="whether to use SSL for Redis connection",
default=False,
)
REDIS_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel mode",
default=False,
)
REDIS_SENTINELS: Optional[str] = Field(
description="Redis Sentinel nodes",
default=None,
)
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
description="Redis Sentinel service name",
default=None,
)
REDIS_SENTINEL_USERNAME: Optional[str] = Field(
description="Redis Sentinel username",
default=None,
)
REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
description="Redis Sentinel password",
default=None,
)
REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Redis Sentinel socket timeout",
default=0.1,
)

View File

@ -1,29 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class HuaweiCloudOBSStorageConfig(BaseModel):
"""
Huawei Cloud OBS storage configs
"""
HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
description="Huawei Cloud OBS bucket name",
default=None,
)
HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
description="Huawei Cloud OBS Access key",
default=None,
)
HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
description="Huawei Cloud OBS Secret key",
default=None,
)
HUAWEI_OBS_SERVER: Optional[str] = Field(
description="Huawei Cloud OBS server URL",
default=None,
)

View File

@ -1,34 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class VolcengineTOSStorageConfig(BaseModel):
"""
Volcengine tos storage configs
"""
VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
description="Volcengine TOS Bucket Name",
default=None,
)
VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
description="Volcengine TOS Access Key",
default=None,
)
VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
description="Volcengine TOS Secret Key",
default=None,
)
VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
description="Volcengine TOS Endpoint URL",
default=None,
)
VOLCENGINE_TOS_REGION: Optional[str] = Field(
description="Volcengine TOS Region",
default=None,
)

View File

@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
@ -9,14 +9,14 @@ class MilvusConfig(BaseSettings):
Milvus configs
"""
MILVUS_URI: Optional[str] = Field(
description="Milvus uri",
default="http://127.0.0.1:19530",
MILVUS_HOST: Optional[str] = Field(
description="Milvus host",
default=None,
)
MILVUS_TOKEN: Optional[str] = Field(
description="Milvus token",
default=None,
MILVUS_PORT: PositiveInt = Field(
description="Milvus RestFul API port",
default=9091,
)
MILVUS_USER: Optional[str] = Field(
@ -29,6 +29,11 @@ class MilvusConfig(BaseSettings):
default=None,
)
MILVUS_SECURE: bool = Field(
description="whether to use SSL connection for Milvus",
default=False,
)
MILVUS_DATABASE: str = Field(
description="Milvus database, default to `default`",
default="default",

View File

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.8.2",
default="0.8.0-beta1",
)
COMMIT_SHA: str = Field(

File diff suppressed because one or more lines are too long

View File

@ -37,7 +37,7 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
from .billing import billing
# Import datasets controllers
from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website
# Import explore controllers
from .explore import (

View File

@ -60,15 +60,23 @@ class InsertExploreAppListApi(Resource):
site = app.site
if not site:
desc = args["desc"] or ""
copy_right = args["copyright"] or ""
privacy_policy = args["privacy_policy"] or ""
custom_disclaimer = args["custom_disclaimer"] or ""
desc = args["desc"] if args["desc"] else ""
copy_right = args["copyright"] if args["copyright"] else ""
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
else:
desc = site.description or args["desc"] or ""
copy_right = site.copyright or args["copyright"] or ""
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
desc = site.description if site.description else args["desc"] if args["desc"] else ""
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
privacy_policy = (
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
)
custom_disclaimer = (
site.custom_disclaimer
if site.custom_disclaimer
else args["custom_disclaimer"]
if args["custom_disclaimer"]
else ""
)
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()

View File

@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource):
def post(self, resource_id):
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_editor:
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = (

View File

@ -94,15 +94,19 @@ class ChatMessageTextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)

View File

@ -20,7 +20,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.helper import DatetimeString
from libs.helper import datetime_string
from libs.login import login_required
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
)
@ -201,11 +201,7 @@ class ChatConversationApi(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
@ -214,11 +210,7 @@ class ChatConversationApi(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
query = query.where(Conversation.created_at < end_datetime_utc)
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join(

View File

@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.helper import datetime_string
from libs.login import login_required
from models.model import AppMode
@ -25,17 +25,14 @@ class DailyMessageStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(*) AS message_count
FROM
messages
WHERE
app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -48,7 +45,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -58,10 +55,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@ -82,17 +79,14 @@ class DailyConversationStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT messages.conversation_id) AS conversation_count
FROM
messages
WHERE
app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -105,7 +99,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -115,10 +109,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@ -139,17 +133,14 @@ class DailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
FROM
messages
WHERE
app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -162,7 +153,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -172,10 +163,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@ -196,18 +187,16 @@ class DailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
SUM(total_price) AS total_price
FROM
messages
WHERE
app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
sum(total_price) as total_price
FROM messages where app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -220,7 +209,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -230,10 +219,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@ -256,26 +245,16 @@ class AverageSessionInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(subquery.message_count) AS interactions
FROM
(
SELECT
m.conversation_id,
COUNT(m.id) AS message_count
FROM
conversations c
JOIN
messages m
ON c.id = m.conversation_id
WHERE
c.override_model_configs IS NULL
AND c.app_id = :app_id"""
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(subquery.message_count) AS interactions
FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
FROM conversations c
JOIN messages m ON c.id = m.conversation_id
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -288,7 +267,7 @@ FROM
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND c.created_at >= :start"
sql_query += " and c.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -298,19 +277,14 @@ FROM
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND c.created_at < :end"
sql_query += " and c.created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += """
GROUP BY m.conversation_id
) subquery
LEFT JOIN
conversations c
ON c.id = subquery.conversation_id
GROUP BY
date
ORDER BY
date"""
GROUP BY m.conversation_id) subquery
LEFT JOIN conversations c on c.id=subquery.conversation_id
GROUP BY date
ORDER BY date"""
response_data = []
@ -333,21 +307,17 @@ class UserSatisfactionRateStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(m.id) AS message_count,
COUNT(mf.id) AS feedback_count
FROM
messages m
LEFT JOIN
message_feedbacks mf
ON mf.message_id=m.id AND mf.rating='like'
WHERE
m.app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
FROM messages m
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
WHERE m.app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -360,7 +330,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND m.created_at >= :start"
sql_query += " and m.created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -370,10 +340,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND m.created_at < :end"
sql_query += " and m.created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@ -399,17 +369,16 @@ class AverageResponseTimeStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(provider_response_latency) AS latency
FROM
messages
WHERE
app_id = :app_id"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(provider_response_latency) as latency
FROM messages
WHERE app_id = :app_id
"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -422,7 +391,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -432,10 +401,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@ -456,20 +425,17 @@ class TokensPerSecondStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
CASE
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
CASE
WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
END as tokens_per_second
FROM
messages
WHERE
app_id = :app_id"""
FROM messages
WHERE app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
@ -482,7 +448,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -492,10 +458,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []

View File

@ -465,6 +465,6 @@ api.add_resource(
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
)
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")

View File

@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.helper import datetime_string
from libs.login import login_required
from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom
@ -26,18 +26,16 @@ class WorkflowDailyRunsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(id) AS runs
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@ -54,7 +52,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -64,10 +62,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@ -88,18 +86,16 @@ class WorkflowDailyTerminalsStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
sql_query = """
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@ -116,7 +112,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -126,10 +122,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@ -150,18 +146,18 @@ class WorkflowDailyTokenCostStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) AS token_count
FROM
workflow_runs
WHERE
app_id = :app_id
AND triggered_from = :triggered_from"""
sql_query = """
SELECT
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
SUM(workflow_runs.total_tokens) as token_count
FROM workflow_runs
WHERE app_id = :app_id
AND triggered_from = :triggered_from
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@ -178,7 +174,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
@ -188,10 +184,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
sql_query += " GROUP BY date order by date"
response_data = []
@ -217,31 +213,27 @@ class WorkflowAverageAppInteractionStatistic(Resource):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
AVG(sub.interactions) AS interactions,
sub.date
FROM
(
SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by,
COUNT(c.id) AS interactions
FROM
workflow_runs c
WHERE
c.app_id = :app_id
AND c.triggered_from = :triggered_from
{{start}}
{{end}}
GROUP BY
date, c.created_by
) sub
GROUP BY
sub.date"""
sql_query = """
SELECT
AVG(sub.interactions) as interactions,
sub.date
FROM
(SELECT
date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by,
COUNT(c.id) AS interactions
FROM workflow_runs c
WHERE c.app_id = :app_id
AND c.triggered_from = :triggered_from
{{start}}
{{end}}
GROUP BY date, c.created_by) sub
GROUP BY sub.date
"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
@ -270,7 +262,7 @@ GROUP BY
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
arg_dict["end"] = end_datetime_utc
else:
sql_query = sql_query.replace("{{end}}", "")

View File

@ -8,7 +8,7 @@ from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
from libs.helper import StrLen, email, timezone
from libs.helper import email, str_len, timezone
from libs.password import hash_password, valid_password
from models.account import AccountStatus
from services.account_service import RegisterService
@ -37,7 +37,7 @@ class ActivateApi(Resource):
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json"

View File

@ -71,7 +71,7 @@ class OAuthCallback(Resource):
account = _generate_account(provider, user_info)
# Check account status
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
return {"error": "Account is banned or closed."}, 403
if account.status == AccountStatus.PENDING.value:
@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if not account:
# Create account
account_name = user_info.name or "Dify"
account_name = user_info.name if user_info.name else "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)

View File

@ -18,7 +18,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@ -110,26 +110,6 @@ class DatasetListApi(Resource):
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"external_api_template_id",
type=str,
nullable=True,
required=False,
)
parser.add_argument(
"provider",
type=str,
nullable=True,
choices=Dataset.PROVIDER_LIST,
required=False,
default="vendor",
)
parser.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@ -143,9 +123,6 @@ class DatasetListApi(Resource):
indexing_technique=args["indexing_technique"],
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_api_template_id=args["external_api_template_id"],
external_knowledge_id=args["external_knowledge_id"],
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
@ -422,7 +399,7 @@ class DatasetIndexingEstimateApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -573,7 +550,12 @@ class DatasetApiBaseUrlApi(Resource):
@login_required
@account_initialization_required
def get(self):
return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
return {
"api_base_url": (
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
)
+ "/v1"
}
class DatasetRetrievalSettingApi(Resource):

View File

@ -302,8 +302,6 @@ class DatasetInitApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@ -311,8 +309,6 @@ class DatasetInitApi(Resource):
raise Forbidden()
if args["indexing_technique"] == "high_quality":
if args["embedding_model"] is None or args["embedding_model_provider"] is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager()
model_manager.get_default_model_instance(
@ -354,7 +350,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in ["completed", "error"]:
raise DocumentAlreadyFinishedError()
data_process_rule = document.dataset_process_rule
@ -421,7 +417,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
info_list = []
extract_settings = []
for document in documents:
if document.indexing_status in {"completed", "error"}:
if document.indexing_status in ["completed", "error"]:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
# format document files info
@ -665,7 +661,7 @@ class DocumentProcessingApi(DocumentResource):
db.session.commit()
elif action == "resume":
if document.indexing_status not in {"paused", "error"}:
if document.indexing_status not in ["paused", "error"]:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None

View File

@ -1,254 +0,0 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from services.external_knowledge_service import ExternalDatasetService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 100:
raise ValueError("Name must be between 1 to 100 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class ExternalApiTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
api_templates, total = ExternalDatasetService.get_external_api_templates(
page, limit, current_user.current_tenant_id, search
)
response = {
"data": [item.to_dict() for item in api_templates],
"has_more": len(api_templates) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
nullable=False,
required=True,
help="Description is required. Description must be between 1 to 400 characters.",
type=_validate_description_length,
)
parser.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
try:
api_template = ExternalDatasetService.create_api_template(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return api_template.to_dict(), 201
class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, api_template_id):
api_template_id = str(api_template_id)
api_template = ExternalDatasetService.get_api_template(api_template_id)
if api_template is None:
raise NotFound("API template not found.")
return api_template.to_dict(), 200
@setup_required
@login_required
@account_initialization_required
def patch(self, api_template_id):
api_template_id = str(api_template_id)
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
nullable=False,
required=True,
help="description is required. Description must be between 1 to 400 characters.",
type=_validate_description_length,
)
parser.add_argument(
"settings",
type=dict,
location="json",
nullable=False,
required=True,
)
args = parser.parse_args()
ExternalDatasetService.validate_api_list(args["settings"])
api_template = ExternalDatasetService.update_api_template(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
api_template_id=api_template_id,
args=args,
)
return api_template.to_dict(), 200
@setup_required
@login_required
@account_initialization_required
def delete(self, api_template_id):
api_template_id = str(api_template_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor or current_user.is_dataset_operator:
raise Forbidden()
ExternalDatasetService.delete_api_template(current_user.current_tenant_id, api_template_id)
return {"result": "success"}, 204
class ExternalApiUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, api_template_id):
api_template_id = str(api_template_id)
external_api_template_is_using = ExternalDatasetService.external_api_template_use_check(api_template_id)
return {"is_using": external_api_template_is_using}, 200
class ExternalDatasetInitApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("api_template_id", type=str, required=True, nullable=True, location="json")
# parser.add_argument('name', nullable=False, required=True,
# help='name is required. Name must be between 1 to 100 characters.',
# type=_validate_name)
# parser.add_argument('description', type=str, required=True, nullable=True, location='json')
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
parser.add_argument("process_parameter", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
# validate args
ExternalDatasetService.document_create_args_validate(
current_user.current_tenant_id, args["api_template_id"], args["process_parameter"]
)
try:
dataset, documents, batch = ExternalDatasetService.init_external_dataset(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
args=args,
)
except Exception as ex:
raise ProviderNotInitializeError(ex.description)
response = {"dataset": dataset, "documents": documents, "batch": batch}
return response
class ExternalDatasetCreateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("external_api_template_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
parser.add_argument(
"name",
nullable=False,
required=True,
help="name is required. Name must be between 1 to 100 characters.",
type=_validate_name,
)
parser.add_argument("description", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
try:
dataset = ExternalDatasetService.create_external_dataset(
tenant_id=current_user.current_tenant_id,
user_id=current_user.id,
args=args,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 201
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-api-template")
api.add_resource(ExternalApiTemplateApi, "/datasets/external-api-template/<uuid:api_template_id>")
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-api-template/<uuid:api_template_id>/use-check")

View File

@ -47,7 +47,6 @@ class HitTestingApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrival_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@ -58,7 +57,6 @@ class HitTestingApi(Resource):
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrival_model"],
limit=10,
)

View File

@ -1,49 +0,0 @@
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from services.external_knowledge_service import ExternalDatasetService
class TestExternalApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"top_k",
nullable=False,
required=True,
type=int,
)
parser.add_argument(
"score_threshold",
nullable=False,
required=True,
type=float,
)
args = parser.parse_args()
result = ExternalDatasetService.test_external_knowledge_retrival(
args["top_k"], args["score_threshold"]
)
response = {
"data": [item.to_dict() for item in api_templates],
"has_more": len(api_templates) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrival-documents")

View File

@ -18,7 +18,9 @@ class NotSetupError(BaseHTTPException):
class NotInitValidateError(BaseHTTPException):
error_code = "not_init_validated"
description = "Init validation has not been completed yet. Please proceed with the init validation process first."
description = (
"Init validation has not been completed yet. " "Please proceed with the init validation process first."
)
code = 401

View File

@ -81,15 +81,19 @@ class ChatTextApi(InstalledAppResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)

View File

@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
def post(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
def post(self, installed_app, task_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

View File

@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
def get(self, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
def delete(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
def post(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -31,7 +31,7 @@ class InstalledAppsListApi(Resource):
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
"is_pinned": installed_app.is_pinned,
"last_used_at": installed_app.last_used_at,
"editable": current_user.role in {"owner", "admin"},
"editable": current_user.role in ["owner", "admin"],
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
}
for installed_app in installed_apps

View File

@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
def get(self, installed_app, message_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
message_id = str(message_id)

View File

@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
"""Retrieve app parameters."""
app_model = installed_app.app
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -4,7 +4,7 @@ from flask import session
from flask_restful import Resource, reqparse
from configs import dify_config
from libs.helper import StrLen
from libs.helper import str_len
from models.model import DifySetup
from services.account_service import TenantService
@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
raise AlreadySetupError()
parser = reqparse.RequestParser()
parser.add_argument("password", type=StrLen(30), required=True, location="json")
parser.add_argument("password", type=str_len(30), required=True, location="json")
input_password = parser.parse_args()["password"]
if input_password != os.environ.get("INIT_PASSWORD"):

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restful import Resource, reqparse
from configs import dify_config
from libs.helper import StrLen, email, get_remote_ip
from libs.helper import email, get_remote_ip, str_len
from libs.password import valid_password
from models.model import DifySetup
from services.account_service import RegisterService, TenantService
@ -40,7 +40,7 @@ class SetupApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("name", type=StrLen(30), required=True, location="json")
parser.add_argument("name", type=str_len(30), required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args()

View File

@ -13,7 +13,7 @@ from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 50:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 50 characters.")
return name

View File

@ -218,7 +218,7 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
api.add_resource(
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>"
)
api.add_resource(

View File

@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):
return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
args["provider_name"] or "",
args["provider_name"] if args["provider_name"] else "",
args["tool_name"],
args["credentials"],
args["parameters"],

View File

@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
raise TooManyFilesError()
extension = file.filename.split(".")[-1]
if extension.lower() not in {"svg", "png"}:
if extension.lower() not in ["svg", "png"]:
raise UnsupportedFileTypeError()
try:

View File

@ -64,8 +64,7 @@ def cloud_edition_billing_resource_check(resource: str):
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places,
# so we need to check the source of the request from datasets
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
source = request.args.get("source")
if source == "datasets":
abort(403, "The number of documents has reached the limit of your subscription.")

View File

@ -42,7 +42,7 @@ class AppParameterApi(Resource):
@marshal_with(parameters_fields)
def get(self, app_model: App):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -79,15 +79,19 @@ class TextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception:
voice = None
response = AudioService.transcript_tts(

View File

@ -96,7 +96,7 @@ class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -144,7 +144,7 @@ class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

View File

@ -18,7 +18,7 @@ class ConversationApi(Resource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
@marshal_with(simple_conversation_fields)
def delete(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
@marshal_with(simple_conversation_fields)
def post(self, app_model: App, end_user: EndUser, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -76,7 +76,7 @@ class MessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model: App, end_user: EndUser):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
try:

View File

@ -1,7 +1,6 @@
import logging
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
@ -23,12 +22,10 @@ from core.errors.error import (
)
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from services.app_generate_service import AppGenerateService
from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__)
@ -116,30 +113,6 @@ class WorkflowTaskStopApi(Resource):
return {"result": "success"}
class WorkflowAppLogApi(Resource):
@validate_app_token
@marshal_with(workflow_app_log_pagination_fields)
def get(self, app_model: App):
"""
Get workflow app logs
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model, args=args
)
return workflow_app_log_pagination
api.add_resource(WorkflowRunApi, "/workflows/run")
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
api.add_resource(WorkflowAppLogApi, "/workflows/logs")

View File

@ -82,26 +82,6 @@ class DatasetListApi(DatasetApiResource):
required=False,
nullable=False,
)
parser.add_argument(
"external_api_template_id",
type=str,
nullable=True,
required=False,
default="_validate_name",
)
parser.add_argument(
"provider",
type=str,
nullable=True,
required=False,
default="vendor",
)
parser.add_argument(
"external_knowledge_id",
type=str,
nullable=True,
required=False,
)
args = parser.parse_args()
try:
@ -111,9 +91,6 @@ class DatasetListApi(DatasetApiResource):
indexing_technique=args["indexing_technique"],
account=current_user,
permission=args["permission"],
provider=args["provider"],
external_api_template_id=args["external_api_template_id"],
external_knowledge_id=args["external_knowledge_id"],
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()

View File

@ -37,7 +37,7 @@ class SegmentApi(DatasetApiResource):
if not document:
raise NotFound("Document not found.")
if document.indexing_status != "completed":
raise NotFound("Document is not completed.")
raise NotFound("Document is already completed.")
if not document.enabled:
raise NotFound("Document is disabled.")
# check embedding model setting
@ -67,7 +67,7 @@ class SegmentApi(DatasetApiResource):
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
return {"error": "Segemtns is required"}, 400
def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""

View File

@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
@marshal_with(parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()

View File

@ -78,15 +78,19 @@ class TextApi(WebApiResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
voice = (
args.get("voice")
if args.get("voice")
else app_model.app_model_config.text_to_speech_dict.get("voice")
)
except Exception:
voice = None

View File

@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
def post(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
class ChatStopApi(WebApiResource):
def post(self, app_model, end_user, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

View File

@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
class ConversationApi(WebApiResource):
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
class ConversationPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)
@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
class ConversationUnPinApi(WebApiResource):
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
conversation_id = str(c_id)

View File

@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotChatAppError()
parser = reqparse.RequestParser()
@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
class MessageSuggestedQuestionApi(WebApiResource):
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
raise NotCompletionAppError()
message_id = str(message_id)

View File

@ -80,8 +80,7 @@ def _validate_web_sso_token(decoded, system_features, app_code):
if not source or source != "sso":
raise WebSSOAuthRequiredError()
# Check if SSO is not enforced for web, and if the token source is SSO,
# raise an error and redirect to normal passport login
# Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
source = decoded.get("token_source")
if source and source == "sso":

View File

@ -1 +1 @@
import core.moderation.base
import core.moderation.base

View File

@ -1,7 +1,6 @@
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Optional, Union, cast
@ -46,25 +45,22 @@ from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
def __init__(
self,
tenant_id: str,
application_generate_entity: AgentChatAppGenerateEntity,
conversation: Conversation,
app_config: AgentChatAppConfig,
model_config: ModelConfigWithCredentialsEntity,
config: AgentEntity,
queue_manager: AppQueueManager,
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None,
) -> None:
def __init__(self, tenant_id: str,
application_generate_entity: AgentChatAppGenerateEntity,
conversation: Conversation,
app_config: AgentChatAppConfig,
model_config: ModelConfigWithCredentialsEntity,
config: AgentEntity,
queue_manager: AppQueueManager,
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
) -> None:
"""
Agent runner
:param tenant_id: tenant id
@ -92,7 +88,9 @@ class BaseAgentRunner(AppRunner):
self.message = message
self.user_id = user_id
self.memory = memory
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
self.history_prompt_messages = self.organize_agent_history(
prompt_messages=prompt_messages or []
)
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
@ -113,16 +111,12 @@ class BaseAgentRunner(AppRunner):
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback,
hit_callback=hit_callback
)
# get how many agent thoughts have been created
self.agent_thought_count = (
db.session.query(MessageAgentThought)
.filter(
MessageAgentThought.message_id == self.message.id,
)
.count()
)
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id,
).count()
db.session.close()
# check if model supports stream tool call
@ -141,26 +135,25 @@ class BaseAgentRunner(AppRunner):
self.query = None
self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(
self, app_generate_entity: AgentChatAppGenerateEntity
) -> AgentChatAppGenerateEntity:
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-> AgentChatAppGenerateEntity:
"""
Repack app generate entity
"""
if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
return app_generate_entity
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
"""
convert tool to prompt message tool
convert tool to prompt message tool
"""
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from,
invoke_from=self.application_generate_entity.invoke_from
)
tool_entity.load_variables(self.variables_pool)
@ -171,7 +164,7 @@ class BaseAgentRunner(AppRunner):
"type": "object",
"properties": {},
"required": [],
},
}
)
parameters = tool_entity.get_all_runtime_parameters()
@ -184,19 +177,19 @@ class BaseAgentRunner(AppRunner):
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
message_tool.parameters["properties"][parameter.name] = {
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or "",
"description": parameter.llm_description or '',
}
if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
message_tool.parameters["required"].append(parameter.name)
message_tool.parameters['required'].append(parameter.name)
return message_tool, tool_entity
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
"""
convert dataset retriever tool to prompt message tool
@ -208,24 +201,24 @@ class BaseAgentRunner(AppRunner):
"type": "object",
"properties": {},
"required": [],
},
}
)
for parameter in tool.get_runtime_parameters():
parameter_type = "string"
prompt_tool.parameters["properties"][parameter.name] = {
parameter_type = 'string'
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or "",
"description": parameter.llm_description or '',
}
if parameter.required:
if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters["required"].append(parameter.name)
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
"""
Init tools
"""
@ -268,51 +261,51 @@ class BaseAgentRunner(AppRunner):
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
prompt_tool.parameters["properties"][parameter.name] = {
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or "",
"description": parameter.llm_description or '',
}
if len(enum) > 0:
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters["required"].append(parameter.name)
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool
def create_agent_thought(
self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
def create_agent_thought(self, message_id: str, message: str,
tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
"""
Create agent thought
"""
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
thought="",
thought='',
tool=tool_name,
tool_labels_str="{}",
tool_meta_str="{}",
tool_labels_str='{}',
tool_meta_str='{}',
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
message_files=json.dumps(messages_ids) if messages_ids else '',
answer='',
observation='',
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
tokens=0,
total_price=0,
position=self.agent_thought_count + 1,
currency="USD",
currency='USD',
latency=0,
created_by_role="account",
created_by_role='account',
created_by=self.user_id,
)
@ -325,22 +318,22 @@ class BaseAgentRunner(AppRunner):
return thought
def save_agent_thought(
self,
agent_thought: MessageAgentThought,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
answer: str,
messages_ids: list[str],
llm_usage: LLMUsage = None,
) -> MessageAgentThought:
def save_agent_thought(self,
agent_thought: MessageAgentThought,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
answer: str,
messages_ids: list[str],
llm_usage: LLMUsage = None) -> MessageAgentThought:
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()
if thought is not None:
agent_thought.thought = thought
@ -363,7 +356,7 @@ class BaseAgentRunner(AppRunner):
observation = json.dumps(observation, ensure_ascii=False)
except Exception as e:
observation = json.dumps(observation)
agent_thought.observation = observation
if answer is not None:
@ -371,7 +364,7 @@ class BaseAgentRunner(AppRunner):
if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids)
if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit
@ -384,7 +377,7 @@ class BaseAgentRunner(AppRunner):
# check if tool labels is not empty
labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(";") if agent_thought.tool else []
tools = agent_thought.tool.split(';') if agent_thought.tool else []
for tool in tools:
if not tool:
continue
@ -393,7 +386,7 @@ class BaseAgentRunner(AppRunner):
if tool_label:
labels[tool] = tool_label.to_dict()
else:
labels[tool] = {"en_US": tool, "zh_Hans": tool}
labels[tool] = {'en_US': tool, 'zh_Hans': tool}
agent_thought.tool_labels_str = json.dumps(labels)
@ -408,18 +401,14 @@ class BaseAgentRunner(AppRunner):
db.session.commit()
db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
db_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)
db_variables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
).first()
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
@ -436,14 +425,9 @@ class BaseAgentRunner(AppRunner):
if isinstance(prompt_message, SystemPromptMessage):
result.append(prompt_message)
messages: list[Message] = (
db.session.query(Message)
.filter(
Message.conversation_id == self.message.conversation_id,
)
.order_by(Message.created_at.asc())
.all()
)
messages: list[Message] = db.session.query(Message).filter(
Message.conversation_id == self.message.conversation_id,
).order_by(Message.created_at.asc()).all()
for message in messages:
if message.id == self.message.id:
@ -455,13 +439,13 @@ class BaseAgentRunner(AppRunner):
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tools = tools.split(';')
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception as e:
tool_inputs = {tool: {} for tool in tools}
tool_inputs = { tool: {} for tool in tools }
try:
tool_responses = json.loads(agent_thought.observation)
except Exception as e:
@ -470,33 +454,27 @@ class BaseAgentRunner(AppRunner):
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
),
)
)
tool_call_response.append(
ToolPromptMessage(
content=tool_responses.get(tool, agent_thought.observation),
tool_calls.append(AssistantPromptMessage.ToolCall(
id=tool_call_id,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
tool_call_id=tool_call_id,
arguments=json.dumps(tool_inputs.get(tool, {})),
)
)
))
tool_call_response.append(ToolPromptMessage(
content=tool_responses.get(tool, agent_thought.observation),
name=tool,
tool_call_id=tool_call_id,
))
result.extend(
[
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response,
]
)
result.extend([
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response
])
if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
@ -518,7 +496,10 @@ class BaseAgentRunner(AppRunner):
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
file_objs = message_file_parser.transform_message_files(
files,
file_extra_config
)
else:
file_objs = []

View File

@ -25,19 +25,17 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_ignore_observation_providers = ['wenxin']
_historic_prompt_messages: list[PromptMessage] = None
_agent_scratchpad: list[AgentScratchpadUnit] = None
_instruction: str = None
_query: str = None
_prompt_messages_tools: list[PromptMessage] = None
def run(
self,
message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
def run(self, message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
"""
@ -48,16 +46,17 @@ class CotAgentRunner(BaseAgentRunner, ABC):
trace_manager = app_generate_entity.trace_manager
# check model mode
if "Observation" not in app_generate_entity.model_conf.stop:
if 'Observation' not in app_generate_entity.model_conf.stop:
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
app_generate_entity.model_conf.stop.append("Observation")
app_generate_entity.model_conf.stop.append('Observation')
app_config = self.app_config
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
self._instruction = self._fill_in_inputs_from_external_data_tools(
instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@ -66,14 +65,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
function_call_state = True
llm_usage = {"usage": None}
final_answer = ""
llm_usage = {
'usage': None
}
final_answer = ''
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage = final_llm_usage_dict['usage']
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
@ -93,13 +94,17 @@ class CotAgentRunner(BaseAgentRunner, ABC):
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
)
if iteration_step > 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
@ -120,20 +125,21 @@ class CotAgentRunner(BaseAgentRunner, ABC):
raise ValueError("failed to invoke llm")
usage_dict = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
react_chunks = CotAgentOutputParser.handle_react_stream_output(
chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
agent_response='',
thought='',
action_str='',
observation='',
action=None,
)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
@ -148,51 +154,61 @@ class CotAgentRunner(BaseAgentRunner, ABC):
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=chunk
),
usage=None
)
)
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
scratchpad.thought = scratchpad.thought.strip(
) or 'I am thinking about how to help you'
self._agent_scratchpad.append(scratchpad)
# get llm usage
if "usage" in usage_dict:
increase_usage(llm_usage, usage_dict["usage"])
if 'usage' in usage_dict:
increase_usage(llm_usage, usage_dict['usage'])
else:
usage_dict["usage"] = LLMUsage.empty_usage()
usage_dict['usage'] = LLMUsage.empty_usage()
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else "",
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input
} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought,
observation="",
observation='',
answer=scratchpad.agent_response,
messages_ids=[],
llm_usage=usage_dict["usage"],
llm_usage=usage_dict['usage']
)
if not scratchpad.is_final():
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = ""
final_answer = ''
else:
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input)
final_answer = json.dumps(
scratchpad.action.action_input)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
final_answer = f"{scratchpad.action.action_input}"
final_answer = f'{scratchpad.action.action_input}'
except json.JSONDecodeError:
final_answer = f"{scratchpad.action.action_input}"
final_answer = f'{scratchpad.action.action_input}'
else:
function_call_state = True
# action is tool call, invoke tool
@ -208,18 +224,21 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought,
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
observation={
scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={
scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict["usage"],
llm_usage=usage_dict['usage']
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
@ -231,45 +250,44 @@ class CotAgentRunner(BaseAgentRunner, ABC):
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
index=0,
message=AssistantPromptMessage(
content=final_answer
),
usage=llm_usage['usage']
),
system_fingerprint="",
system_fingerprint=''
)
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name="",
tool_name='',
tool_input={},
tool_invoke_meta={},
thought=final_answer,
observation={},
answer=final_answer,
messages_ids=[],
messages_ids=[]
)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer
),
PublishFrom.APPLICATION_MANAGER,
)
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]:
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None
) -> tuple[str, ToolInvokeMeta]:
"""
handle invoke action
:param action: action
@ -308,12 +326,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files
for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file_id
), PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file_id)
@ -323,7 +342,10 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
convert dict to action
"""
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
return AgentScratchpadUnit.Action(
action_name=action['action'],
action_input=action['action_input']
)
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
"""
@ -331,7 +353,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
for key, value in inputs.items():
try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
except Exception as e:
continue
@ -348,14 +370,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
organize prompt messages
organize prompt messages
"""
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
format assistant message
format assistant message
"""
message = ""
message = ''
for scratchpad in agent_scratchpad:
if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}"
@ -368,11 +390,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message
def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
organize historic prompt messages
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
@ -383,8 +403,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if not current_scratchpad:
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
thought=message.content or 'I am thinking about how to help you',
action_str='',
action=None,
observation=None,
)
@ -393,9 +413,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
action_input=json.loads(
message.tool_calls[0].function.arguments)
)
current_scratchpad.action_str = json.dumps(
current_scratchpad.action.to_dict()
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except:
pass
elif isinstance(message, ToolPromptMessage):
@ -403,19 +426,23 @@ class CotAgentRunner(BaseAgentRunner, ABC):
current_scratchpad.observation = message.content
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads)
))
scratchpads = []
current_scratchpad = None
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads)
))
historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory,
memory=self.memory
).get_prompt()
return historic_prompts

View File

@ -19,15 +19,14 @@ class CotChatAgentRunner(CotAgentRunner):
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
system_prompt = first_prompt \
.replace("{{instruction}}", self._instruction) \
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
@ -44,7 +43,7 @@ class CotChatAgentRunner(CotAgentRunner):
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize
Organize
"""
# organize system prompt
system_message = self._organize_system_prompt()
@ -54,7 +53,7 @@ class CotChatAgentRunner(CotAgentRunner):
if not agent_scratchpad:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content="")
assistant_message = AssistantPromptMessage(content='')
for unit in agent_scratchpad:
if unit.is_final():
assistant_message.content += f"Final Answer: {unit.agent_response}"
@ -72,15 +71,18 @@ class CotChatAgentRunner(CotAgentRunner):
if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages(
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
)
historic_messages = self._organize_historic_prompt_messages([
system_message,
*query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
])
messages = [
system_message,
*historic_messages,
*query_messages,
*assistant_messages,
UserPromptMessage(content="continue"),
UserPromptMessage(content='continue')
]
else:
# organize historic prompt messages

View File

@ -13,12 +13,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
return system_prompt
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@ -48,7 +46,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ""
assistant_prompt = ''
for unit in agent_scratchpad:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
@ -63,10 +61,9 @@ class CotCompletionAgentRunner(CotAgentRunner):
query_prompt = f"Question: {self._query}"
# join all messages
prompt = (
system_prompt.replace("{{historic_messages}}", historic_prompt)
.replace("{{agent_scratchpad}}", assistant_prompt)
prompt = system_prompt \
.replace("{{historic_messages}}", historic_prompt) \
.replace("{{agent_scratchpad}}", assistant_prompt) \
.replace("{{query}}", query_prompt)
)
return [UserPromptMessage(content=prompt)]
return [UserPromptMessage(content=prompt)]

View File

@ -8,7 +8,6 @@ class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api", "workflow"]
provider_id: str
tool_name: str
@ -19,7 +18,6 @@ class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
"""
first_prompt: str
next_iteration: str
@ -33,7 +31,6 @@ class AgentScratchpadUnit(BaseModel):
"""
Action Entity.
"""
action_name: str
action_input: Union[dict, str]
@ -42,8 +39,8 @@ class AgentScratchpadUnit(BaseModel):
Convert to dictionary.
"""
return {
"action": self.action_name,
"action_input": self.action_input,
'action': self.action_name,
'action_input': self.action_input,
}
agent_response: Optional[str] = None
@ -57,10 +54,10 @@ class AgentScratchpadUnit(BaseModel):
Check if the scratchpad unit is final.
"""
return self.action is None or (
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
'final' in self.action.action_name.lower() and
'answer' in self.action.action_name.lower()
)
class AgentEntity(BaseModel):
"""
Agent Entity.
@ -70,9 +67,8 @@ class AgentEntity(BaseModel):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = "chain-of-thought"
FUNCTION_CALLING = "function-calling"
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling'
provider: str
model: str

View File

@ -24,9 +24,11 @@ from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
def run(self,
message: Message, query: str, **kwargs: Any
) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
@ -43,17 +45,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# continue to run until there is not any tool call
function_call_state = True
llm_usage = {"usage": None}
final_answer = ""
llm_usage = {
'usage': None
}
final_answer = ''
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage = final_llm_usage_dict['usage']
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
@ -71,7 +75,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
)
# recalc llm max tokens
@ -91,11 +99,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
# save full response
response = ""
response = ''
# save tool call names and inputs
tool_call_names = ""
tool_call_inputs = ""
tool_call_names = ''
tool_call_inputs = ''
current_llm_usage = None
@ -103,22 +111,24 @@ class FunctionCallAgentRunner(BaseAgentRunner):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
@ -138,14 +148,16 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
if result.usage:
increase_usage(llm_usage, result.usage)
@ -159,12 +171,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
response += result.message.content
if not result.message.content:
result.message.content = ""
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
result.message.content = ''
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
@ -173,29 +185,32 @@ class FunctionCallAgentRunner(BaseAgentRunner):
index=0,
message=result.message,
usage=result.usage,
),
)
)
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
assistant_message = AssistantPromptMessage(
content='',
tool_calls=[]
)
if tool_calls:
assistant_message.tool_calls = [
assistant_message.tool_calls=[
AssistantPromptMessage.ToolCall(
id=tool_call[0],
type="function",
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
),
)
for tool_call in tool_calls
name=tool_call[1],
arguments=json.dumps(tool_call[2], ensure_ascii=False)
)
) for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message)
# save thought
self.save_agent_thought(
agent_thought=agent_thought,
agent_thought=agent_thought,
tool_name=tool_call_names,
tool_input=tool_call_inputs,
thought=response,
@ -203,13 +218,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
observation=None,
answer=response,
messages_ids=[],
llm_usage=current_llm_usage,
llm_usage=current_llm_usage
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
final_answer += response + "\n"
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
final_answer += response + '\n'
# call tools
tool_responses = []
@ -220,7 +235,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
}
else:
# invoke tool
@ -240,49 +255,50 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file_id
), PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file_id)
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict(),
"meta": tool_invoke_meta.to_dict()
}
tool_responses.append(tool_response)
if tool_response["tool_response"] is not None:
if tool_response['tool_response'] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=tool_response["tool_response"],
content=tool_response['tool_response'],
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
)
if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
agent_thought=agent_thought,
tool_name=None,
tool_input=None,
thought=None,
thought=None,
tool_invoke_meta={
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
tool_response['tool_call_name']: tool_response['meta']
for tool_response in tool_responses
},
observation={
tool_response["tool_call_name"]: tool_response["tool_response"]
tool_response['tool_call_name']: tool_response['tool_response']
for tool_response in tool_responses
},
answer=None,
messages_ids=message_file_ids,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
messages_ids=message_file_ids
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool
for prompt_tool in prompt_messages_tools:
@ -292,18 +308,15 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer
),
PublishFrom.APPLICATION_MANAGER,
)
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
"""
@ -312,7 +325,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
@ -321,9 +334,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True
return False
def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract tool calls from llm result chunk
@ -333,19 +344,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls:
args = {}
if prompt_message.function.arguments != "":
if prompt_message.function.arguments != '':
args = json.loads(prompt_message.function.arguments)
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
args,
))
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
@ -356,22 +365,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
args = {}
if prompt_message.function.arguments != "":
if prompt_message.function.arguments != '':
args = json.loads(prompt_message.function.arguments)
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
args,
))
return tool_calls
def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Initialize system message
"""
@ -379,13 +384,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return [
SystemPromptMessage(content=prompt_template),
]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
@ -399,7 +404,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
@ -410,21 +415,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
prompt_message.content = '\n'.join([
content.data if content.type == PromptMessageContentType.TEXT else
'[image]' if content.type == PromptMessageContentType.IMAGE else
'[file]'
for content in prompt_message.content
])
return prompt_messages
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, [])
@ -432,10 +433,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
memory=self.memory
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
prompt_messages = [
*self.history_prompt_messages,
*query_prompt_messages,
*self._current_thoughts
]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

View File

@ -9,9 +9,8 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(json_str):
try:
action = json.loads(json_str)
@ -23,7 +22,7 @@ class CotAgentOutputParser:
action = action[0]
for key, value in action.items():
if "input" in key.lower():
if 'input' in key.lower():
action_input = value
else:
action_name = value
@ -34,37 +33,37 @@ class CotAgentOutputParser:
action_input=action_input,
)
else:
return json_str or ""
return json_str or ''
except:
return json_str or ""
return json_str or ''
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
yield parse_action(json_text)
code_block_cache = ""
code_block_cache = ''
code_block_delimiter_count = 0
in_code_block = False
json_cache = ""
json_cache = ''
json_quote_count = 0
in_json = False
got_json = False
action_cache = ""
action_str = "action:"
action_cache = ''
action_str = 'action:'
action_idx = 0
thought_cache = ""
thought_str = "thought:"
thought_cache = ''
thought_str = 'thought:'
thought_idx = 0
for response in llm_response:
if response.delta.usage:
usage_dict["usage"] = response.delta.usage
usage_dict['usage'] = response.delta.usage
response = response.delta.message.content
if not isinstance(response, str):
continue
@ -73,24 +72,24 @@ class CotAgentOutputParser:
index = 0
while index < len(response):
steps = 1
delta = response[index : index + steps]
last_character = response[index - 1] if index > 0 else ""
delta = response[index:index+steps]
last_character = response[index-1] if index > 0 else ''
if delta == "`":
if delta == '`':
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
yield code_block_cache
code_block_cache = ""
code_block_cache = ''
else:
code_block_cache += delta
code_block_delimiter_count = 0
if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in {"\n", " ", ""}:
if last_character not in ['\n', ' ', '']:
index += steps
yield delta
continue
@ -98,7 +97,7 @@ class CotAgentOutputParser:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ""
action_cache = ''
action_idx = 0
index += steps
continue
@ -106,18 +105,18 @@ class CotAgentOutputParser:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ""
action_cache = ''
action_idx = 0
index += steps
continue
else:
if action_cache:
yield action_cache
action_cache = ""
action_cache = ''
action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in {"\n", " ", ""}:
if last_character not in ['\n', ' ', '']:
index += steps
yield delta
continue
@ -125,7 +124,7 @@ class CotAgentOutputParser:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ""
thought_cache = ''
thought_idx = 0
index += steps
continue
@ -133,31 +132,31 @@ class CotAgentOutputParser:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ""
thought_cache = ''
thought_idx = 0
index += steps
continue
else:
if thought_cache:
yield thought_cache
thought_cache = ""
thought_cache = ''
thought_idx = 0
if code_block_delimiter_count == 3:
if in_code_block:
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ""
code_block_cache = ''
in_code_block = not in_code_block
code_block_delimiter_count = 0
if not in_code_block:
# handle single json
if delta == "{":
if delta == '{':
json_quote_count += 1
in_json = True
json_cache += delta
elif delta == "}":
elif delta == '}':
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
@ -173,12 +172,12 @@ class CotAgentOutputParser:
if got_json:
got_json = False
yield parse_action(json_cache)
json_cache = ""
json_cache = ''
json_quote_count = 0
in_json = False
if not in_code_block and not in_json:
yield delta.replace("`", "")
yield delta.replace('`', '')
index += steps
@ -187,3 +186,4 @@ class CotAgentOutputParser:
if json_cache:
yield parse_action(json_cache)

View File

@ -41,8 +41,7 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
{{historic_messages}}
Question: {{query}}
{{agent_scratchpad}}
Thought:""" # noqa: E501
Thought:"""
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
Thought:"""
@ -87,20 +86,19 @@ Action:
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
""" # noqa: E501
"""
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
REACT_PROMPT_TEMPLATES = {
"english": {
"chat": {
"prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
},
"completion": {
"prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
'english': {
'chat': {
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
},
'completion': {
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
}
}
}
}

View File

@ -26,24 +26,34 @@ class BaseAppConfigManager:
config_dict = dict(config_dict.items())
additional_features = AppAdditionalFeatures()
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
config=config_dict
)
additional_features.file_upload = FileUploadConfigManager.convert(
config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
config=config_dict,
is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
)
additional_features.opening_statement, additional_features.suggested_questions = (
OpeningStatementConfigManager.convert(config=config_dict)
)
additional_features.opening_statement, additional_features.suggested_questions = \
OpeningStatementConfigManager.convert(
config=config_dict
)
additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
config=config_dict
)
additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)
additional_features.more_like_this = MoreLikeThisConfigManager.convert(
config=config_dict
)
additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)
additional_features.speech_to_text = SpeechToTextConfigManager.convert(
config=config_dict
)
additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)
additional_features.text_to_speech = TextToSpeechConfigManager.convert(
config=config_dict
)
return additional_features

View File

@ -7,24 +7,25 @@ from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
if not sensitive_word_avoidance_dict:
return None
if sensitive_word_avoidance_dict.get("enabled"):
if sensitive_word_avoidance_dict.get('enabled'):
return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config"),
type=sensitive_word_avoidance_dict.get('type'),
config=sensitive_word_avoidance_dict.get('config'),
)
else:
return None
@classmethod
def validate_and_set_defaults(
cls, tenant_id, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False}
config["sensitive_word_avoidance"] = {
"enabled": False
}
if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type")
@ -40,6 +41,10 @@ class SensitiveWordAvoidanceConfigManager:
typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
ModerationFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=sensitive_word_avoidance_config
)
return config, ["sensitive_word_avoidance"]

View File

@ -12,70 +12,67 @@ class AgentConfigManager:
:param config: model config args
"""
if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
agent_dict = config.get("agent_mode", {})
agent_strategy = agent_dict.get("strategy", "cot")
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode']:
if agent_strategy == "function_call":
agent_dict = config.get('agent_mode', {})
agent_strategy = agent_dict.get('strategy', 'cot')
if agent_strategy == 'function_call':
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy in {"cot", "react"}:
elif agent_strategy == 'cot' or agent_strategy == 'react':
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if config["model"]["provider"] == "openai":
if config['model']['provider'] == 'openai':
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = []
for tool in agent_dict.get("tools", []):
for tool in agent_dict.get('tools', []):
keys = tool.keys()
if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]:
continue
agent_tool_properties = {
"provider_type": tool["provider_type"],
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
'provider_type': tool['provider_type'],
'provider_id': tool['provider_id'],
'tool_name': tool['tool_name'],
'tool_parameters': tool.get('tool_parameters', {})
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router",
"router",
}:
agent_prompt = agent_dict.get("prompt", None) or {}
if 'strategy' in config['agent_mode'] and \
config['agent_mode']['strategy'] not in ['react_router', 'router']:
agent_prompt = agent_dict.get('prompt', None) or {}
# check model mode
model_mode = config.get("model", {}).get("mode", "completion")
if model_mode == "completion":
model_mode = config.get('model', {}).get('mode', 'completion')
if model_mode == 'completion':
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get(
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
),
next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
),
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['completion'][
'agent_scratchpad']),
)
else:
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get(
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
),
next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
),
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
)
return AgentEntity(
provider=config["model"]["provider"],
model=config["model"]["name"],
provider=config['model']['provider'],
model=config['model']['name'],
strategy=strategy,
prompt=agent_prompt_entity,
tools=agent_tools,
max_iteration=agent_dict.get("max_iteration", 5),
max_iteration=agent_dict.get('max_iteration', 5)
)
return None

View File

@ -15,38 +15,39 @@ class DatasetConfigManager:
:param config: model config args
"""
dataset_ids = []
if "datasets" in config.get("dataset_configs", {}):
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
if 'datasets' in config.get('dataset_configs', {}):
datasets = config.get('dataset_configs', {}).get('datasets', {
'strategy': 'router',
'datasets': []
})
for dataset in datasets.get("datasets", []):
for dataset in datasets.get('datasets', []):
keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != "dataset":
if len(keys) == 0 or keys[0] != 'dataset':
continue
dataset = dataset["dataset"]
dataset = dataset['dataset']
if "enabled" not in dataset or not dataset["enabled"]:
if 'enabled' not in dataset or not dataset['enabled']:
continue
dataset_id = dataset.get("id", None)
dataset_id = dataset.get('id', None)
if dataset_id:
dataset_ids.append(dataset_id)
if (
"agent_mode" in config
and config["agent_mode"]
and "enabled" in config["agent_mode"]
and config["agent_mode"]["enabled"]
):
agent_dict = config.get("agent_mode", {})
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode'] \
and config['agent_mode']['enabled']:
for tool in agent_dict.get("tools", []):
agent_dict = config.get('agent_mode', {})
for tool in agent_dict.get('tools', []):
keys = tool.keys()
if len(keys) == 1:
# old standard
key = list(tool.keys())[0]
if key != "dataset":
if key != 'dataset':
continue
tool_item = tool[key]
@ -54,28 +55,30 @@ class DatasetConfigManager:
if "enabled" not in tool_item or not tool_item["enabled"]:
continue
dataset_id = tool_item["id"]
dataset_id = tool_item['id']
dataset_ids.append(dataset_id)
if len(dataset_ids) == 0:
return None
# dataset configs
if "dataset_configs" in config and config.get("dataset_configs"):
dataset_configs = config.get("dataset_configs")
if 'dataset_configs' in config and config.get('dataset_configs'):
dataset_configs = config.get('dataset_configs')
else:
dataset_configs = {"retrieval_model": "multiple"}
query_variable = config.get("dataset_query_variable")
dataset_configs = {
'retrieval_model': 'multiple'
}
query_variable = config.get('dataset_query_variable')
if dataset_configs["retrieval_model"] == "single":
if dataset_configs['retrieval_model'] == 'single':
return DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
),
),
dataset_configs['retrieval_model']
)
)
)
else:
return DatasetEntity(
@ -83,15 +86,15 @@ class DatasetConfigManager:
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs["retrieval_model"]
dataset_configs['retrieval_model']
),
top_k=dataset_configs.get("top_k", 4),
score_threshold=dataset_configs.get("score_threshold"),
reranking_model=dataset_configs.get("reranking_model"),
weights=dataset_configs.get("weights"),
reranking_enabled=dataset_configs.get("reranking_enabled", True),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
),
top_k=dataset_configs.get('top_k', 4),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
)
)
@classmethod
@ -108,10 +111,13 @@ class DatasetConfigManager:
# dataset_configs
if not config.get("dataset_configs"):
config["dataset_configs"] = {"retrieval_model": "single"}
config["dataset_configs"] = {'retrieval_model': 'single'}
if not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
config["dataset_configs"]["datasets"] = {
"strategy": "router",
"datasets": []
}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
@ -119,9 +125,8 @@ class DatasetConfigManager:
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {}
).get("datasets")
need_manual_query_datasets = (config.get("dataset_configs")
and config["dataset_configs"].get("datasets", {}).get("datasets"))
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
@ -143,7 +148,10 @@ class DatasetConfigManager:
"""
# Extract dataset config for legacy compatibility
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}
config["agent_mode"] = {
"enabled": False,
"tools": []
}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@ -167,7 +175,7 @@ class DatasetConfigManager:
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0]
if key == "dataset":
@ -180,7 +188,7 @@ class DatasetConfigManager:
if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if "id" not in tool_item:
if 'id' not in tool_item:
raise ValueError("id is required in dataset")
try:

View File

@ -11,7 +11,9 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter:
@classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
def convert(cls, app_config: EasyUIBasedAppConfig,
skip_check: bool = False) \
-> ModelConfigWithCredentialsEntity:
"""
Convert app model config dict to entity.
:param app_config: app config
@ -23,7 +25,9 @@ class ModelConfigConverter:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
tenant_id=app_config.tenant_id,
provider=model_config.provider,
model_type=ModelType.LLM
)
provider_name = provider_model_bundle.configuration.provider.provider
@ -34,7 +38,8 @@ class ModelConfigConverter:
# check model credentials
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM, model=model_config.model
model_type=ModelType.LLM,
model=model_config.model
)
if model_credentials is None:
@ -46,7 +51,8 @@ class ModelConfigConverter:
if not skip_check:
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model, model_type=ModelType.LLM
model=model_config.model,
model_type=ModelType.LLM
)
if provider_model is None:
@ -63,18 +69,24 @@ class ModelConfigConverter:
# model config
completion_params = model_config.parameters
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
model_mode = model_config.mode
if not model_mode:
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
mode_enum = model_type_instance.get_model_mode(
model=model_config.model,
credentials=model_credentials
)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
model_schema = model_type_instance.get_model_schema(
model_config.model,
model_credentials
)
if not skip_check and not model_schema:
raise ValueError(f"Model {model_name} not exist.")

View File

@ -13,23 +13,23 @@ class ModelConfigManager:
:param config: model config args
"""
# model config
model_config = config.get("model")
model_config = config.get('model')
if not model_config:
raise ValueError("model is required")
completion_params = model_config.get("completion_params")
completion_params = model_config.get('completion_params')
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
# get model mode
model_mode = model_config.get("mode")
model_mode = model_config.get('mode')
return ModelConfigEntity(
provider=config["model"]["provider"],
model=config["model"]["name"],
provider=config['model']['provider'],
model=config['model']['name'],
mode=model_mode,
parameters=completion_params,
stop=stop,
@ -43,7 +43,7 @@ class ModelConfigManager:
:param tenant_id: tenant id
:param config: app model config args
"""
if "model" not in config:
if 'model' not in config:
raise ValueError("model is required")
if not isinstance(config["model"], dict):
@ -52,16 +52,17 @@ class ModelConfigManager:
# model.provider
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name
if "name" not in config["model"]:
if 'name' not in config["model"]:
raise ValueError("model.name is required")
provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"], model_type=ModelType.LLM
provider=config["model"]["provider"],
model_type=ModelType.LLM
)
if not models:
@ -79,12 +80,12 @@ class ModelConfigManager:
# model.mode
if model_mode:
config["model"]["mode"] = model_mode
config['model']["mode"] = model_mode
else:
config["model"]["mode"] = "completion"
config['model']["mode"] = "completion"
# model.completion_params
if "completion_params" not in config["model"]:
if 'completion_params' not in config["model"]:
raise ValueError("model.completion_params is required")
config["model"]["completion_params"] = cls.validate_model_completion_params(
@ -100,7 +101,7 @@ class ModelConfigManager:
raise ValueError("model.completion_params must be of object type")
# stop
if "stop" not in cp:
if 'stop' not in cp:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")

View File

@ -14,33 +14,39 @@ class PromptTemplateConfigManager:
if not config.get("prompt_type"):
raise ValueError("prompt_type is required")
prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
simple_prompt_template = config.get("pre_prompt", "")
return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
return PromptTemplateEntity(
prompt_type=prompt_type,
simple_prompt_template=simple_prompt_template
)
else:
advanced_chat_prompt_template = None
chat_prompt_config = config.get("chat_prompt_config", {})
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append(
{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
chat_prompt_messages.append({
"text": message["text"],
"role": PromptMessageRole.value_of(message["role"])
})
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
messages=chat_prompt_messages
)
advanced_completion_prompt_template = None
completion_prompt_config = config.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
"prompt": completion_prompt_config["prompt"]["text"],
'prompt': completion_prompt_config['prompt']['text'],
}
if "conversation_histories_role" in completion_prompt_config:
completion_prompt_template_params["role_prefix"] = {
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
if 'conversation_histories_role' in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@ -50,7 +56,7 @@ class PromptTemplateConfigManager:
return PromptTemplateEntity(
prompt_type=prompt_type,
advanced_chat_prompt_template=advanced_chat_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template
)
@classmethod
@ -66,7 +72,7 @@ class PromptTemplateConfigManager:
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config["prompt_type"] not in prompt_type_vals:
if config['prompt_type'] not in prompt_type_vals:
raise ValueError(f"prompt_type must be in {prompt_type_vals}")
# chat_prompt_config
@ -83,28 +89,27 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
)
if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required "
"when prompt_type is advanced")
model_mode_vals = [mode.value for mode in ModelMode]
if config["model"]["mode"] not in model_mode_vals:
if config['model']["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
if not user_prefix:
config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
if not assistant_prefix:
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
if config["model"]["mode"] == ModelMode.CHAT.value:
prompt_list = config["chat_prompt_config"]["prompt"]
if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']
if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10")

View File

@ -16,49 +16,51 @@ class BasicVariablesConfigManager:
variable_entities = []
# old external_data_tools
external_data_tools = config.get("external_data_tools", [])
external_data_tools = config.get('external_data_tools', [])
for external_data_tool in external_data_tools:
if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=external_data_tool["variable"],
type=external_data_tool["type"],
config=external_data_tool["config"],
variable=external_data_tool['variable'],
type=external_data_tool['type'],
config=external_data_tool['config']
)
)
# variables and external_data_tools
for variables in config.get("user_input_form", []):
for variables in config.get('user_input_form', []):
variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type]
if "config" not in variable:
if 'config' not in variable:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=variable["variable"], type=variable["type"], config=variable["config"]
variable=variable['variable'],
type=variable['type'],
config=variable['config']
)
)
elif variable_type in {
elif variable_type in [
VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
}:
]:
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description"),
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options"),
default=variable.get("default"),
variable=variable.get('variable'),
description=variable.get('description'),
label=variable.get('label'),
required=variable.get('required', False),
max_length=variable.get('max_length'),
options=variable.get('options'),
default=variable.get('default'),
)
)
@ -97,17 +99,17 @@ class BasicVariablesConfigManager:
variables = []
for item in config["user_input_form"]:
key = list(item.keys())[0]
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
form_item = item[key]
if "label" not in form_item:
if 'label' not in form_item:
raise ValueError("label is required in user_input_form")
if not isinstance(form_item["label"], str):
raise ValueError("label in user_input_form must be of string type")
if "variable" not in form_item:
if 'variable' not in form_item:
raise ValueError("variable is required in user_input_form")
if not isinstance(form_item["variable"], str):
@ -115,24 +117,26 @@ class BasicVariablesConfigManager:
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
if pattern.match(form_item["variable"]) is None:
raise ValueError("variable in user_input_form must be a string, and cannot start with a number")
raise ValueError("variable in user_input_form must be a string, "
"and cannot start with a number")
variables.append(form_item["variable"])
if "required" not in form_item or not form_item["required"]:
if 'required' not in form_item or not form_item["required"]:
form_item["required"] = False
if not isinstance(form_item["required"], bool):
raise ValueError("required in user_input_form must be of boolean type")
if key == "select":
if "options" not in form_item or not form_item["options"]:
if 'options' not in form_item or not form_item["options"]:
form_item["options"] = []
if not isinstance(form_item["options"], list):
raise ValueError("options in user_input_form must be a list of strings")
if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
if "default" in form_item and form_item['default'] \
and form_item["default"] not in form_item["options"]:
raise ValueError("default value in user_input_form must be in the options list")
return config, ["user_input_form"]
@ -164,6 +168,10 @@ class BasicVariablesConfigManager:
typ = tool["type"]
config = tool["config"]
ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)
ExternalDataToolFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=config
)
return config, ["external_data_tools"]

View File

@ -12,7 +12,6 @@ class ModelConfigEntity(BaseModel):
"""
Model Config Entity.
"""
provider: str
model: str
mode: Optional[str] = None
@ -24,7 +23,6 @@ class AdvancedChatMessageEntity(BaseModel):
"""
Advanced Chat Message Entity.
"""
text: str
role: PromptMessageRole
@ -33,7 +31,6 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
"""
Advanced Chat Prompt Template Entity.
"""
messages: list[AdvancedChatMessageEntity]
@ -46,7 +43,6 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Role Prefix Entity.
"""
user: str
assistant: str
@ -64,12 +60,11 @@ class PromptTemplateEntity(BaseModel):
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = "simple"
ADVANCED = "advanced"
SIMPLE = 'simple'
ADVANCED = 'advanced'
@classmethod
def value_of(cls, value: str) -> "PromptType":
def value_of(cls, value: str) -> 'PromptType':
"""
Get value of given mode.
@ -79,7 +74,7 @@ class PromptTemplateEntity(BaseModel):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid prompt type value {value}")
raise ValueError(f'invalid prompt type value {value}')
prompt_type: PromptType
simple_prompt_template: Optional[str] = None
@ -92,7 +87,7 @@ class VariableEntityType(str, Enum):
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
EXTERNAL_DATA_TOOL = "external-data-tool"
class VariableEntity(BaseModel):
@ -115,7 +110,6 @@ class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
"""
variable: str
type: str
config: dict[str, Any] = {}
@ -131,12 +125,11 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = "single"
MULTIPLE = "multiple"
SINGLE = 'single'
MULTIPLE = 'multiple'
@classmethod
def value_of(cls, value: str) -> "RetrieveStrategy":
def value_of(cls, value: str) -> 'RetrieveStrategy':
"""
Get value of given mode.
@ -146,24 +139,25 @@ class DatasetRetrieveConfigEntity(BaseModel):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid retrieve strategy value {value}")
raise ValueError(f'invalid retrieve strategy value {value}')
query_variable: Optional[str] = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy
top_k: Optional[int] = None
score_threshold: Optional[float] = 0.0
rerank_mode: Optional[str] = "reranking_model"
score_threshold: Optional[float] = .0
rerank_mode: Optional[str] = 'reranking_model'
reranking_model: Optional[dict] = None
weights: Optional[dict] = None
reranking_enabled: Optional[bool] = True
class DatasetEntity(BaseModel):
"""
Dataset Config Entity.
"""
dataset_ids: list[str]
retrieve_config: DatasetRetrieveConfigEntity
@ -172,7 +166,6 @@ class SensitiveWordAvoidanceEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
type: str
config: dict[str, Any] = {}
@ -181,7 +174,6 @@ class TextToSpeechEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
@ -191,11 +183,12 @@ class TracingConfigEntity(BaseModel):
"""
Tracing Config Entity.
"""
enabled: bool
tracing_provider: str
class AppAdditionalFeatures(BaseModel):
file_upload: Optional[FileExtraConfig] = None
opening_statement: Optional[str] = None
@ -207,12 +200,10 @@ class AppAdditionalFeatures(BaseModel):
text_to_speech: Optional[TextToSpeechEntity] = None
trace_config: Optional[TracingConfigEntity] = None
class AppConfig(BaseModel):
"""
Application Config Entity.
"""
tenant_id: str
app_id: str
app_mode: AppMode
@ -225,17 +216,15 @@ class EasyUIBasedAppModelConfigFrom(Enum):
"""
App Model Config From.
"""
ARGS = "args"
APP_LATEST_CONFIG = "app-latest-config"
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
ARGS = 'args'
APP_LATEST_CONFIG = 'app-latest-config'
CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
class EasyUIBasedAppConfig(AppConfig):
"""
Easy UI Based App Config Entity.
"""
app_model_config_from: EasyUIBasedAppModelConfigFrom
app_model_config_id: str
app_model_config_dict: dict
@ -249,5 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
"""
Workflow UI Based App Config Entity.
"""
workflow_id: str

View File

@ -13,19 +13,21 @@ class FileUploadConfigManager:
:param config: model config args
:param is_vision: if True, the feature is vision feature
"""
file_upload_dict = config.get("file_upload")
file_upload_dict = config.get('file_upload')
if file_upload_dict:
if file_upload_dict.get("image"):
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
if file_upload_dict.get('image'):
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
image_config = {
"number_limits": file_upload_dict["image"]["number_limits"],
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
'number_limits': file_upload_dict['image']['number_limits'],
'transfer_methods': file_upload_dict['image']['transfer_methods']
}
if is_vision:
image_config["detail"] = file_upload_dict["image"]["detail"]
image_config['detail'] = file_upload_dict['image']['detail']
return FileExtraConfig(image_config=image_config)
return FileExtraConfig(
image_config=image_config
)
return None
@ -47,21 +49,21 @@ class FileUploadConfigManager:
if not config["file_upload"].get("image"):
config["file_upload"]["image"] = {"enabled": False}
if config["file_upload"]["image"]["enabled"]:
number_limits = config["file_upload"]["image"]["number_limits"]
if config['file_upload']['image']['enabled']:
number_limits = config['file_upload']['image']['number_limits']
if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]")
if is_vision:
detail = config["file_upload"]["image"]["detail"]
if detail not in {"high", "low"}:
detail = config['file_upload']['image']['detail']
if detail not in ['high', 'low']:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
transfer_methods = config['file_upload']['image']['transfer_methods']
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in {"remote_url", "local_file"}:
if method not in ['remote_url', 'local_file']:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
return config, ["file_upload"]

View File

@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
:param config: model config args
"""
more_like_this = False
more_like_this_dict = config.get("more_like_this")
more_like_this_dict = config.get('more_like_this')
if more_like_this_dict:
if more_like_this_dict.get("enabled"):
if more_like_this_dict.get('enabled'):
more_like_this = True
return more_like_this
@ -22,7 +22,9 @@ class MoreLikeThisConfigManager:
:param config: app model config args
"""
if not config.get("more_like_this"):
config["more_like_this"] = {"enabled": False}
config["more_like_this"] = {
"enabled": False
}
if not isinstance(config["more_like_this"], dict):
raise ValueError("more_like_this must be of dict type")

View File

@ -1,3 +1,5 @@
class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[str, list]:
@ -7,10 +9,10 @@ class OpeningStatementConfigManager:
:param config: model config args
"""
# opening statement
opening_statement = config.get("opening_statement")
opening_statement = config.get('opening_statement')
# suggested questions
suggested_questions_list = config.get("suggested_questions")
suggested_questions_list = config.get('suggested_questions')
return opening_statement, suggested_questions_list

View File

@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get("retriever_resource")
retriever_resource_dict = config.get('retriever_resource')
if retriever_resource_dict:
if retriever_resource_dict.get("enabled"):
if retriever_resource_dict.get('enabled'):
show_retrieve_source = True
return show_retrieve_source
@ -17,7 +17,9 @@ class RetrievalResourceConfigManager:
:param config: app model config args
"""
if not config.get("retriever_resource"):
config["retriever_resource"] = {"enabled": False}
config["retriever_resource"] = {
"enabled": False
}
if not isinstance(config["retriever_resource"], dict):
raise ValueError("retriever_resource must be of dict type")

View File

@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
:param config: model config args
"""
speech_to_text = False
speech_to_text_dict = config.get("speech_to_text")
speech_to_text_dict = config.get('speech_to_text')
if speech_to_text_dict:
if speech_to_text_dict.get("enabled"):
if speech_to_text_dict.get('enabled'):
speech_to_text = True
return speech_to_text
@ -22,7 +22,9 @@ class SpeechToTextConfigManager:
:param config: app model config args
"""
if not config.get("speech_to_text"):
config["speech_to_text"] = {"enabled": False}
config["speech_to_text"] = {
"enabled": False
}
if not isinstance(config["speech_to_text"], dict):
raise ValueError("speech_to_text must be of dict type")

View File

@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
:param config: model config args
"""
suggested_questions_after_answer = False
suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
if suggested_questions_after_answer_dict:
if suggested_questions_after_answer_dict.get("enabled"):
if suggested_questions_after_answer_dict.get('enabled'):
suggested_questions_after_answer = True
return suggested_questions_after_answer
@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
:param config: app model config args
"""
if not config.get("suggested_questions_after_answer"):
config["suggested_questions_after_answer"] = {"enabled": False}
config["suggested_questions_after_answer"] = {
"enabled": False
}
if not isinstance(config["suggested_questions_after_answer"], dict):
raise ValueError("suggested_questions_after_answer must be of dict type")
if (
"enabled" not in config["suggested_questions_after_answer"]
or not config["suggested_questions_after_answer"]["enabled"]
):
if "enabled" not in config["suggested_questions_after_answer"] or not \
config["suggested_questions_after_answer"]["enabled"]:
config["suggested_questions_after_answer"]["enabled"] = False
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):

View File

@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
:param config: model config args
"""
text_to_speech = None
text_to_speech_dict = config.get("text_to_speech")
text_to_speech_dict = config.get('text_to_speech')
if text_to_speech_dict:
if text_to_speech_dict.get("enabled"):
if text_to_speech_dict.get('enabled'):
text_to_speech = TextToSpeechEntity(
enabled=text_to_speech_dict.get("enabled"),
voice=text_to_speech_dict.get("voice"),
language=text_to_speech_dict.get("language"),
enabled=text_to_speech_dict.get('enabled'),
voice=text_to_speech_dict.get('voice'),
language=text_to_speech_dict.get('language'),
)
return text_to_speech
@ -29,7 +29,11 @@ class TextToSpeechConfigManager:
:param config: app model config args
"""
if not config.get("text_to_speech"):
config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
config["text_to_speech"] = {
"enabled": False,
"voice": "",
"language": ""
}
if not isinstance(config["text_to_speech"], dict):
raise ValueError("text_to_speech must be of dict type")

View File

@ -1,3 +1,4 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
@ -18,13 +19,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
"""
Advanced Chatbot App Config Entity.
"""
pass
class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
def get_app_config(cls, app_model: App,
workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
@ -33,9 +34,13 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
additional_features=cls.convert_features(features_dict, app_mode),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
)
return app_config
@ -53,7 +58,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config, is_vision=False
config=config,
is_vision=False
)
related_config_keys.extend(current_related_config_keys)
@ -63,8 +69,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config
)
config)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
@ -81,7 +86,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
tenant_id=tenant_id,
config=config,
only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
@ -91,3 +98,4 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
@ -34,8 +34,7 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
app_model: App,
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
@ -45,8 +44,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
app_model: App,
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
@ -55,14 +53,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -73,37 +71,44 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get("query"):
raise ValueError("query is required")
if not args.get('query'):
raise ValueError('query is required')
query = args["query"]
query = args['query']
if not isinstance(query, str):
raise ValueError("query must be a string")
raise ValueError('query must be a string')
query = query.replace("\x00", "")
inputs = args["inputs"]
query = query.replace('\x00', '')
inputs = args['inputs']
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', False)
}
# get conversation
conversation = None
conversation_id = args.get("conversation_id")
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(
app_model=app_model, conversation_id=conversation_id, user=user
)
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# parse files
files = args["files"] if args.get("files") else []
files = args['files'] if args.get('files') else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
@ -125,7 +130,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
trace_manager=trace_manager
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -135,12 +140,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream,
stream=stream
)
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
) -> dict[str, Any] | Generator[str, Any, None]:
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -152,13 +161,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param stream: is stream
"""
if not node_id:
raise ValueError("node_id is required")
raise ValueError('node_id is required')
if args.get("inputs") is None:
raise ValueError("inputs is required")
if args.get('inputs') is None:
raise ValueError('inputs is required')
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
@ -166,15 +178,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_config=app_config,
conversation_id=None,
inputs={},
query="",
query='',
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
extras={
"auto_generate_conversation_name": False
},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
node_id=node_id,
inputs=args['inputs']
)
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@ -184,19 +199,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=None,
stream=stream,
stream=stream
)
def _generate(
self,
*,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -212,7 +225,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
is_first_conversation = True
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
@ -227,21 +243,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"context": contextvars.copy_context(),
},
)
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'context': contextvars.copy_context(),
})
worker_thread.start()
@ -256,17 +269,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
return AdvancedChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context,
) -> None:
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@ -289,21 +302,22 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
message=message
)
runner.run()
except GenerateTaskStoppedError:
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG", "false").lower() == "true":
if os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
@ -349,7 +363,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
return generate_task_pipeline.process()
except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
raise GenerateTaskStoppedException()
else:
logger.exception(e)
raise e

View File

@ -21,11 +21,14 @@ class AudioTrunk:
self.status = status
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
content_text=text_content.strip(),
user="responding_tts",
tenant_id=tenant_id,
voice=voice
)
@ -41,26 +44,28 @@ def _process_future(future_queue, audio_queue):
except Exception as e:
logging.getLogger(__name__).warning(e)
break
audio_queue.put(AudioTrunk("finish", b""))
audio_queue.put(AudioTrunk("finish", b''))
class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ""
self.msg_text = ''
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self.match = re.compile(r"[。.!?]")
self.match = re.compile(r'[。.!?]')
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.TTS
tenant_id=self.tenant_id,
model_type=ModelType.TTS
)
self.voices = self.model_instance.get_tts_voices()
values = [voice.get("value") for voice in self.voices]
values = [voice.get('value') for voice in self.voices]
self.voice = voice
if not voice or voice not in values:
self.voice = self.voices[0].get("value")
self.voice = self.voices[0].get('value')
self.MAX_SENTENCE = 2
self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime).start()
@ -80,9 +85,8 @@ class AppGeneratorTTSPublisher:
message = self._msg_queue.get()
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
)
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
self.model_instance, self.tenant_id, self.voice)
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
@ -90,27 +94,28 @@ class AppGeneratorTTSPublisher:
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):
self.msg_text += message.event.outputs.get("output", "")
self.msg_text += message.event.outputs.get('output', '')
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
self.MAX_SENTENCE += 1
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
)
text_content = ''.join(sentence_arr)
futures_result = self.executor.submit(_invoiceTTS, text_content,
self.model_instance,
self.tenant_id,
self.voice)
future_queue.put(futures_result)
if text_tmp:
self.msg_text = text_tmp
else:
self.msg_text = ""
self.msg_text = ''
except Exception as e:
self.logger.warning(e)
break
future_queue.put(None)
def check_and_get_audio(self) -> AudioTrunk | None:
def checkAndGetAudio(self) -> AudioTrunk | None:
try:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:

View File

@ -19,7 +19,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
)
from core.moderation.base import ModerationError
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
@ -38,11 +38,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
"""
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message
) -> None:
"""
:param application_generate_entity: application generate entity
@ -66,14 +66,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
raise ValueError('Workflow not initialized')
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
@ -81,7 +81,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id = self.application_generate_entity.user_id
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
if self.application_generate_entity.single_iteration_run:
@ -89,7 +89,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
@ -98,27 +98,26 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id,
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity,
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity
):
return
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id,
ConversationVariable.conversation_id == self.conversation.id,
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
@ -175,7 +174,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
@ -191,12 +190,12 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._handle_event(workflow_entry, event)
def handle_input_moderation(
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
) -> bool:
"""
Handle input moderation
@ -217,15 +216,19 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query=query,
message_id=message_id,
)
except ModerationError as e:
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
except ModerationException as e:
self._complete_with_stream_output(
text=str(e),
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
)
return True
return False
def handle_annotation_reply(
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
) -> bool:
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
"""
Handle annotation reply
:param app_record: app record
@ -243,21 +246,32 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
)
if annotation_reply:
self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))
self._publish_event(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
)
self._complete_with_stream_output(
text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
text=annotation_reply.content,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
def _complete_with_stream_output(self,
text: str,
stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param text: text
:return:
"""
self._publish_event(QueueTextChunkEvent(text=text))
self._publish_event(
QueueTextChunkEvent(
text=text
)
)
self._publish_event(QueueStopEvent(stopped_by=stopped_by))
self._publish_event(
QueueStopEvent(stopped_by=stopped_by)
)

View File

@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
response = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
}
return response
@ -50,15 +50,13 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]:
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -69,14 +67,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
yield 'ping'
continue
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@ -87,9 +85,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]:
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -100,20 +96,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
yield 'ping'
continue
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

View File

@ -65,7 +65,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
@ -73,14 +72,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@ -124,10 +123,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._application_generate_entity.query
self._conversation,
self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
return self._to_stream_response(generator)
@ -145,7 +147,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
extras["metadata"] = stream_response.metadata
extras['metadata'] = stream_response.metadata
return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
@ -156,17 +158,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
message_id=self._message.id,
answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()),
**extras,
),
**extras
)
)
else:
continue
raise Exception("Queue listening stopped unexpectedly.")
raise Exception('Queue listening stopped unexpectedly.')
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
"""
To stream response.
:return:
@ -176,35 +176,32 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response,
stream_response=stream_response
)
def _listen_audio_msg(self, publisher, task_id: str):
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@ -217,7 +214,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.check_and_get_audio()
audio_trunk = tts_publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@ -231,12 +228,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
def _process_stream_response(
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@ -270,18 +267,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
db.session.close()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise Exception('Workflow run not initialized.')
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
workflow_node_execution=workflow_node_execution
)
if response:
@ -292,7 +293,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
workflow_node_execution=workflow_node_execution
)
if response:
@ -303,52 +304,62 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
workflow_node_execution=workflow_node_execution
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
@ -361,16 +372,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
@ -384,10 +399,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
@ -404,7 +420,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
# Save message
@ -417,9 +434,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._refetch_message()
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.refresh(self._message)
@ -429,9 +445,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._refetch_message()
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.refresh(self._message)
@ -451,15 +466,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(
answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
)
yield self._message_to_stream_response(delta_text, self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
raise Exception('Graph runtime state not initialized.')
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
@ -489,9 +502,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
@ -511,7 +523,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
extras=self._application_generate_entity.extras
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@ -521,13 +533,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.copy()
extras['metadata'] = self._task_state.metadata.copy()
if "annotation_reply" in extras["metadata"]:
del extras["metadata"]["annotation_reply"]
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
task_id=self._application_generate_entity.task_id,
id=self._message.id,
**extras
)
def _handle_output_moderation_chunk(self, text: str) -> bool:
@ -541,11 +555,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
QueueTextChunkEvent(
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
)
return True
else:

View File

@ -28,19 +28,15 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
"""
Agent Chatbot App Config Entity.
"""
agent: Optional[AgentEntity] = None
class AgentChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> AgentChatAppConfig:
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
:param app_model: app model
@ -70,12 +66,22 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
agent=AgentConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
agent=AgentConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@ -122,8 +128,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config
)
config)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
@ -140,15 +145,13 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# dataset configs
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
@ -167,7 +170,10 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
:param config: app model config args
"""
if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []}
config["agent_mode"] = {
"enabled": False,
"tools": []
}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@ -181,9 +187,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if config["agent_mode"]["strategy"] not in [
member.value for member in list(PlanningStrategy.__members__.values())
]:
if config["agent_mode"]["strategy"] not in [member.value for member in
list(PlanningStrategy.__members__.values())]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")
if not config["agent_mode"].get("tools"):
@ -205,7 +210,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if key == "dataset":
if "id" not in tool_item:
if 'id' not in tool_item:
raise ValueError("id is required in dataset")
try:

View File

@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
@ -30,8 +30,7 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
app_model: App,
self, app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
@ -40,17 +39,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
app_model: App,
self, app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[dict, None, None]]:
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@ -61,48 +62,60 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param stream: is stream
"""
if not stream:
raise ValueError("Agent Chat App does not support blocking mode")
raise ValueError('Agent Chat App does not support blocking mode')
if not args.get("query"):
raise ValueError("query is required")
if not args.get('query'):
raise ValueError('query is required')
query = args["query"]
query = args['query']
if not isinstance(query, str):
raise ValueError("query must be a string")
raise ValueError('query must be a string')
query = query.replace("\x00", "")
inputs = args["inputs"]
query = query.replace('\x00', '')
inputs = args['inputs']
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', True)
}
# get conversation
conversation = None
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
# validate override model config
override_model_config_dict = None
if args.get("model_config"):
if args.get('model_config'):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError("Only in App debug mode can override model config")
raise ValueError('Only in App debug mode can override model config')
# validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id,
config=args.get('model_config')
)
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
override_model_config_dict["retriever_resource"] = {
"enabled": True
}
# parse files
files = args["files"] if args.get("files") else []
files = args['files'] if args.get('files') else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
else:
file_objs = []
@ -111,7 +124,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict,
override_config_dict=override_model_config_dict
)
# get tracing instance
@ -132,11 +145,14 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
extras=extras,
call_depth=0,
trace_manager=trace_manager,
trace_manager=trace_manager
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -145,20 +161,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread.start()
@ -172,11 +185,13 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
return AgentChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
def _generate_worker(
self,
flask_app: Flask,
self, flask_app: Flask,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
@ -205,17 +220,18 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
)
except GenerateTaskStoppedError:
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

View File

@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought
@ -30,8 +30,7 @@ class AgentChatAppRunner(AppRunner):
"""
def run(
self,
application_generate_entity: AgentChatAppGenerateEntity,
self, application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
@ -66,7 +65,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
query=query
)
memory = None
@ -74,10 +73,13 @@ class AgentChatAppRunner(AppRunner):
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model,
model=application_generate_entity.model_conf.model
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@ -89,7 +91,7 @@ class AgentChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
memory=memory,
memory=memory
)
# moderation
@ -101,15 +103,15 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
message_id=message.id
)
except ModerationError as e:
except ModerationException as e:
self.direct_output(
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream,
stream=application_generate_entity.stream
)
return
@ -120,13 +122,13 @@ class AgentChatAppRunner(AppRunner):
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
invoke_from=application_generate_entity.invoke_from
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER,
PublishFrom.APPLICATION_MANAGER
)
self.direct_output(
@ -134,7 +136,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream,
stream=application_generate_entity.stream
)
return
@ -146,7 +148,7 @@ class AgentChatAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query,
query=query
)
# reorganize all inputs and template to prompt messages
@ -159,14 +161,14 @@ class AgentChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
memory=memory,
memory=memory
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages,
prompt_messages=prompt_messages
)
if hosting_moderation_result:
@ -175,9 +177,9 @@ class AgentChatAppRunner(AppRunner):
agent_entity = app_config.agent
# load tool variables
tool_conversation_variables = self._load_tool_variables(
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
)
tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id)
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
@ -185,7 +187,7 @@ class AgentChatAppRunner(AppRunner):
# init model instance
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model,
model=application_generate_entity.model_conf.model
)
prompt_message, _ = self.organize_prompt_messages(
app_record=app_record,
@ -236,7 +238,7 @@ class AgentChatAppRunner(AppRunner):
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance,
model_instance=model_instance
)
invoke_result = runner.run(
@ -250,21 +252,17 @@ class AgentChatAppRunner(AppRunner):
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
agent=True,
agent=True
)
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id,
)
.first()
)
tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id
).first()
if tool_variables:
# save tool variables to session, so that we can update it later
@ -275,40 +273,34 @@ class AgentChatAppRunner(AppRunner):
conversation_id=conversation_id,
user_id=user_id,
tenant_id=tenant_id,
variables_str="[]",
variables_str='[]',
)
db.session.add(tool_variables)
db.session.commit()
return tool_variables
def _convert_db_variables_to_tool_variables(
self, db_variables: ToolConversationVariables
) -> ToolRuntimeVariablePool:
def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
"""
convert db variables to tool variables
"""
return ToolRuntimeVariablePool(
**{
"conversation_id": db_variables.conversation_id,
"user_id": db_variables.user_id,
"tenant_id": db_variables.tenant_id,
"pool": db_variables.variables,
}
)
return ToolRuntimeVariablePool(**{
'conversation_id': db_variables.conversation_id,
'user_id': db_variables.user_id,
'tenant_id': db_variables.tenant_id,
'pool': db_variables.variables
})
def _get_usage_of_all_agent_thoughts(
self, model_config: ModelConfigWithCredentialsEntity, message: Message
) -> LLMUsage:
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity,
message: Message) -> LLMUsage:
"""
Get usage of all agent thoughts
:param model_config: model config
:param message: message
:return:
"""
agent_thoughts = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
)
agent_thoughts = (db.session.query(MessageAgentThought)
.filter(MessageAgentThought.message_id == message.id).all())
all_message_tokens = 0
all_answer_tokens = 0
@ -320,5 +312,8 @@ class AgentChatAppRunner(AppRunner):
model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage(
model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
model_config.model,
model_config.credentials,
all_message_tokens,
all_answer_tokens
)

View File

@ -23,15 +23,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
}
return response
@ -45,15 +45,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@ -64,14 +63,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
yield 'ping'
continue
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@ -82,9 +81,8 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@ -95,20 +93,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
yield 'ping'
continue
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

View File

@ -13,33 +13,32 @@ class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
def convert(cls, response: Union[
AppBlockingResponse,
Generator[AppStreamResponse, Any, None]
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_full_response(response):
if chunk == "ping":
yield f"event: {chunk}\n\n"
if chunk == 'ping':
yield f'event: {chunk}\n\n'
else:
yield f"data: {chunk}\n\n"
yield f'data: {chunk}\n\n'
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_simple_response(response):
if chunk == "ping":
yield f"event: {chunk}\n\n"
if chunk == 'ping':
yield f'event: {chunk}\n\n'
else:
yield f"data: {chunk}\n\n"
yield f'data: {chunk}\n\n'
return _generate_simple_response()
@ -55,16 +54,14 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-> Generator[str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-> Generator[str, None, None]:
raise NotImplementedError
@classmethod
@ -75,26 +72,24 @@ class AppGenerateResponseConverter(ABC):
:return:
"""
# show_retrieve_source
if "retriever_resources" in metadata:
metadata["retriever_resources"] = []
for resource in metadata["retriever_resources"]:
metadata["retriever_resources"].append(
{
"segment_id": resource["segment_id"],
"position": resource["position"],
"document_name": resource["document_name"],
"score": resource["score"],
"content": resource["content"],
}
)
if 'retriever_resources' in metadata:
metadata['retriever_resources'] = []
for resource in metadata['retriever_resources']:
metadata['retriever_resources'].append({
'segment_id': resource['segment_id'],
'position': resource['position'],
'document_name': resource['document_name'],
'score': resource['score'],
'content': resource['content'],
})
# show annotation reply
if "annotation_reply" in metadata:
del metadata["annotation_reply"]
if 'annotation_reply' in metadata:
del metadata['annotation_reply']
# show usage
if "usage" in metadata:
del metadata["usage"]
if 'usage' in metadata:
del metadata['usage']
return metadata
@ -106,16 +101,16 @@ class AppGenerateResponseConverter(ABC):
:return:
"""
error_responses = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
ValueError: {'code': 'invalid_param', 'status': 400},
ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
QuotaExceededError: {
"code": "provider_quota_exceeded",
"message": "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
"status": 400,
'code': 'provider_quota_exceeded',
'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
'status': 400
},
ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
InvokeError: {"code": "completion_request_error", "status": 400},
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
InvokeError: {'code': 'completion_request_error', 'status': 400}
}
# Determine the response based on the type of exception
@ -125,13 +120,13 @@ class AppGenerateResponseConverter(ABC):
data = v
if data:
data.setdefault("message", getattr(e, "description", str(e)))
data.setdefault('message', getattr(e, 'description', str(e)))
else:
logging.error(e)
data = {
"code": "internal_server_error",
"message": "Internal Server Error, please contact support.",
"status": 500,
'code': 'internal_server_error',
'message': 'Internal Server Error, please contact support.',
'status': 500
}
return data

View File

@ -16,17 +16,17 @@ class BaseAppGenerator:
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f"{var.variable} is required in input form")
raise ValueError(f'{var.variable} is required in input form')
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ""
return var.default or ''
if (
var.type
in {
in (
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
}
)
and user_input_value
and not isinstance(user_input_value, str)
):
@ -34,7 +34,7 @@ class BaseAppGenerator:
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if "." in user_input_value:
if '.' in user_input_value:
return float(user_input_value)
else:
return int(user_input_value)
@ -43,14 +43,14 @@ class BaseAppGenerator:
if var.type == VariableEntityType.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
return user_input_value
def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str):
return value.replace("\x00", "")
return value.replace('\x00', '')
return value

View File

@ -24,7 +24,9 @@ class PublishFrom(Enum):
class AppQueueManager:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom) -> None:
if not user_id:
raise ValueError("user is required")
@ -32,10 +34,9 @@ class AppQueueManager:
self._user_id = user_id
self._invoke_from = invoke_from
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
redis_client.setex(
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800,
f"{user_prefix}-{self._user_id}")
q = queue.Queue()
@ -65,7 +66,8 @@ class AppQueueManager:
# publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed
self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
PublishFrom.TASK_PIPELINE
)
if elapsed_time // 10 > last_ping_time:
@ -86,7 +88,9 @@ class AppQueueManager:
:param pub_from: publish from
:return:
"""
self.publish(QueueErrorEvent(error=e), pub_from)
self.publish(QueueErrorEvent(
error=e
), pub_from)
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
@ -118,8 +122,8 @@ class AppQueueManager:
if result is None:
return
user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
if result.decode('utf-8') != f"{user_prefix}-{user_id}":
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
@ -164,12 +168,10 @@ class AppQueueManager:
for item in data:
self._check_for_sqlalchemy_models(item)
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed."
)
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed.")
class GenerateTaskStoppedError(Exception):
class GenerateTaskStoppedException(Exception):
pass

View File

@ -31,15 +31,12 @@ if TYPE_CHECKING:
class AppRunner:
def get_pre_calculate_rest_tokens(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
) -> int:
def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None) -> int:
"""
Get pre calculate rest tokens
:param app_record: app record
@ -52,20 +49,18 @@ class AppRunner:
"""
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return -1
@ -80,39 +75,36 @@ class AppRunner:
prompt_template_entity=prompt_template_entity,
inputs=inputs,
files=files,
query=query,
query=query
)
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size."
)
raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
return rest_tokens
def recalc_llm_max_tokens(
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage]):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if model_context_tokens is None:
return -1
@ -120,28 +112,27 @@ class AppRunner:
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16)
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
model_config.parameters[parameter_rule.name] = max_tokens
def organize_prompt_messages(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
def organize_prompt_messages(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Organize prompt messages
:param context:
@ -161,54 +152,60 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query or "",
query=query if query else '',
files=files,
context=context,
memory=memory,
model_config=model_config,
model_config=model_config
)
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
memory_config = MemoryConfig(
window=MemoryConfig.WindowConfig(
enabled=False
)
)
model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
prompt_template = CompletionModelPromptTemplate(
text=advanced_completion_prompt_template.prompt
)
if advanced_completion_prompt_template.role_prefix:
memory_config.role_prefix = MemoryConfig.RolePrefix(
user=advanced_completion_prompt_template.role_prefix.user,
assistant=advanced_completion_prompt_template.role_prefix.assistant,
assistant=advanced_completion_prompt_template.role_prefix.assistant
)
else:
prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
prompt_template.append(ChatModelMessage(
text=message.text,
role=message.role
))
prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
query=query or "",
query=query if query else '',
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_config=model_config
)
stop = model_config.stop
return prompt_messages, stop
def direct_output(
self,
queue_manager: AppQueueManager,
app_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None,
) -> None:
def direct_output(self, queue_manager: AppQueueManager,
app_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None) -> None:
"""
Direct output
:param queue_manager: application queue manager
@ -225,10 +222,17 @@ class AppRunner:
chunk = LLMResultChunk(
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=token)
)
)
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueLLMChunkEvent(
chunk=chunk
), PublishFrom.APPLICATION_MANAGER
)
index += 1
time.sleep(0.01)
@ -238,19 +242,15 @@ class AppRunner:
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage or LLMUsage.empty_usage(),
usage=usage if usage else LLMUsage.empty_usage()
),
),
PublishFrom.APPLICATION_MANAGER,
), PublishFrom.APPLICATION_MANAGER
)
def _handle_invoke_result(
self,
invoke_result: Union[LLMResult, Generator],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
) -> None:
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False) -> None:
"""
Handle invoke result
:param invoke_result: invoke result
@ -260,13 +260,21 @@ class AppRunner:
:return:
"""
if not stream:
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
else:
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
) -> None:
def _handle_invoke_result_direct(self, invoke_result: LLMResult,
queue_manager: AppQueueManager,
agent: bool) -> None:
"""
Handle invoke result direct
:param invoke_result: invoke result
@ -277,13 +285,12 @@ class AppRunner:
queue_manager.publish(
QueueMessageEndEvent(
llm_result=invoke_result,
),
PublishFrom.APPLICATION_MANAGER,
), PublishFrom.APPLICATION_MANAGER
)
def _handle_invoke_result_stream(
self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
) -> None:
def _handle_invoke_result_stream(self, invoke_result: Generator,
queue_manager: AppQueueManager,
agent: bool) -> None:
"""
Handle invoke result
:param invoke_result: invoke result
@ -293,13 +300,21 @@ class AppRunner:
"""
model = None
prompt_messages = []
text = ""
text = ''
usage = None
for result in invoke_result:
if not agent:
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueLLMChunkEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
else:
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
queue_manager.publish(
QueueAgentMessageEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
text += result.delta.message.content
@ -316,24 +331,25 @@ class AppRunner:
usage = LLMUsage.empty_usage()
llm_result = LLMResult(
model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage
)
queue_manager.publish(
QueueMessageEndEvent(
llm_result=llm_result,
),
PublishFrom.APPLICATION_MANAGER,
), PublishFrom.APPLICATION_MANAGER
)
def moderation_for_inputs(
self,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
@ -351,17 +367,14 @@ class AppRunner:
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
query=query or "",
query=query if query else '',
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
trace_manager=app_generate_entity.trace_manager
)
def check_hosting_moderation(
self,
application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage],
) -> bool:
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage]) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity
@ -371,7 +384,8 @@ class AppRunner:
"""
hosting_moderation_feature = HostingModerationFeature()
moderation_result = hosting_moderation_feature.check(
application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
application_generate_entity=application_generate_entity,
prompt_messages=prompt_messages
)
if moderation_result:
@ -379,20 +393,18 @@ class AppRunner:
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream,
text="I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream
)
return moderation_result
def fill_in_inputs_from_external_data_tools(
self,
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str,
) -> dict:
def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str) -> dict:
"""
Fill in variable inputs from external data tools if exists.
@ -405,12 +417,18 @@ class AppRunner:
"""
external_data_fetch_feature = ExternalDataFetch()
return external_data_fetch_feature.fetch(
tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
tenant_id=tenant_id,
app_id=app_id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
def query_app_annotations_to_reply(self, app_record: App,
message: Message,
query: str,
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
@ -422,5 +440,9 @@ class AppRunner:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
app_record=app_record,
message=message,
query=query,
user_id=user_id,
invoke_from=invoke_from
)

View File

@ -22,19 +22,15 @@ class ChatAppConfig(EasyUIBasedAppConfig):
"""
Chatbot App Config Entity.
"""
pass
class ChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> ChatAppConfig:
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None) -> ChatAppConfig:
"""
Convert app model config to chat app config
:param app_model: app model
@ -55,7 +51,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
config_dict = app_model_config_dict.copy()
else:
if not override_config_dict:
raise Exception("override_config_dict is required when config_from is ARGS")
raise Exception('override_config_dict is required when config_from is ARGS')
config_dict = override_config_dict
@ -67,11 +63,19 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@ -109,9 +113,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
related_config_keys.extend(current_related_config_keys)
# opening_statement
@ -120,8 +123,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config
)
config)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
@ -137,9 +139,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))

View File

@ -10,7 +10,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
@ -30,8 +30,7 @@ logger = logging.getLogger(__name__)
class ChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
app_model: App,
self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
@ -40,8 +39,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
app_model: App,
self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
@ -49,8 +47,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self,
app_model: App,
self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
@ -65,46 +62,58 @@ class ChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get("query"):
raise ValueError("query is required")
if not args.get('query'):
raise ValueError('query is required')
query = args["query"]
query = args['query']
if not isinstance(query, str):
raise ValueError("query must be a string")
raise ValueError('query must be a string')
query = query.replace("\x00", "")
inputs = args["inputs"]
query = query.replace('\x00', '')
inputs = args['inputs']
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', True)
}
# get conversation
conversation = None
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
# get app model config
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
# validate override model config
override_model_config_dict = None
if args.get("model_config"):
if args.get('model_config'):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError("Only in App debug mode can override model config")
raise ValueError('Only in App debug mode can override model config')
# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=args.get("model_config")
tenant_id=app_model.tenant_id,
config=args.get('model_config')
)
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True}
override_model_config_dict["retriever_resource"] = {
"enabled": True
}
# parse files
files = args["files"] if args.get("files") else []
files = args['files'] if args.get('files') else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
else:
file_objs = []
@ -113,7 +122,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict,
override_config_dict=override_model_config_dict
)
# get tracing instance
@ -132,11 +141,14 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
trace_manager=trace_manager
)
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@ -145,20 +157,17 @@ class ChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id,
message_id=message.id
)
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread.start()
@ -172,16 +181,16 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
return ChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
def _generate_worker(self, flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@ -203,19 +212,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
message=message
)
except GenerateTaskStoppedError:
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

Some files were not shown because too many files have changed in this diff Show More