mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 12:09:27 +08:00
Compare commits
239 Commits
0.7.2
...
feat/exter
| Author | SHA1 | Date | |
|---|---|---|---|
| 37f7d5732a | |||
| dcb033d221 | |||
| 9f894bb3b3 | |||
| cf645c3ba1 | |||
| e896d1e9d7 | |||
| 6dba68f62d | |||
| 3d083b758f | |||
| aa5b2db10a | |||
| b73faae0d0 | |||
| 4788e1c8c8 | |||
| bf16de50fe | |||
| 7e611ffbf3 | |||
| 65162a87b6 | |||
| 445497cf89 | |||
| fa1af8e47b | |||
| 624331472a | |||
| 72b7f8a949 | |||
| 88c9834ef2 | |||
| d882348f39 | |||
| b6ad7a1e06 | |||
| 6f7625fa47 | |||
| de7bc22649 | |||
| 52857dc0a6 | |||
| 032dd93b2f | |||
| 5b18e851d2 | |||
| f01602b570 | |||
| 0123498452 | |||
| f55e06d8bf | |||
| b613b11422 | |||
| 8efae1cba2 | |||
| bf55b1910f | |||
| 71b4480c4a | |||
| b6b1057a18 | |||
| 5b98acde2f | |||
| aad6f340b3 | |||
| a1104ab97e | |||
| 1ab81b4972 | |||
| 06b66216d7 | |||
| cd3eaed335 | |||
| 9d80d7def7 | |||
| 84ac5ccc8f | |||
| 5dfd7abb2b | |||
| 24af4b9313 | |||
| 6613b8f2e0 | |||
| 08c486452f | |||
| a45ac6ab98 | |||
| 80a322aaa2 | |||
| 82f7875a52 | |||
| 4637ddaa7f | |||
| 8d2269f762 | |||
| 5f03e66489 | |||
| a9c1f1a041 | |||
| 49cee773c5 | |||
| 89e81873c4 | |||
| c78828ab7c | |||
| e90d3c29ab | |||
| 153807f243 | |||
| 5db0b56c5b | |||
| 404db1ae5b | |||
| 02c4b1af71 | |||
| aa11659062 | |||
| d4985fb3aa | |||
| 8815511ccb | |||
| 40fb4d16ef | |||
| c69f5b07ba | |||
| 56c90e212a | |||
| 0f14873255 | |||
| 0bb7569d46 | |||
| ec57922bb6 | |||
| 781d294f49 | |||
| f515af2232 | |||
| fe8191b899 | |||
| 4d2cd6703b | |||
| 9ca0e56a8a | |||
| 292220c596 | |||
| 53f37a6704 | |||
| 75c1a82556 | |||
| c5b3777d93 | |||
| 678bbf8fe8 | |||
| 342607f4a4 | |||
| 5f4cdd66fa | |||
| 91942e37ff | |||
| 60913970dc | |||
| 82c42b9ec5 | |||
| 2a3d8c25bc | |||
| cee0c51dbb | |||
| fdbbdb706f | |||
| f6dfe23cf8 | |||
| ffd4bf8bf0 | |||
| bb3002b173 | |||
| d4dc54447a | |||
| d109881410 | |||
| d1605952b0 | |||
| 2cf1187b32 | |||
| 178730266d | |||
| dabfd74622 | |||
| 5da0182800 | |||
| ed37439ef7 | |||
| af92f19291 | |||
| 86f7f245e4 | |||
| 2d690801d1 | |||
| fede54be77 | |||
| 85ff82a694 | |||
| c8df92d0eb | |||
| 144d30d7ef | |||
| 4313d92e6b | |||
| 0695543f63 | |||
| 0bec6a037c | |||
| 3ff9a1f24a | |||
| a771eea4f6 | |||
| e7c77d961b | |||
| 61a0ca9e0d | |||
| 551b33c8e5 | |||
| fa34b9aed6 | |||
| bbb609179f | |||
| a27d4d58ec | |||
| 50d92f0fd4 | |||
| a15791e788 | |||
| 954580a4af | |||
| ab7d79275e | |||
| d3658166fb | |||
| 54b72bdd0a | |||
| d28446301f | |||
| 9050f92e5b | |||
| feefeb44d7 | |||
| d542b15cc0 | |||
| 2d7954c7da | |||
| b1918dae5e | |||
| 031a0b576d | |||
| 0cef25ef8c | |||
| cdb08be951 | |||
| 900fd82a92 | |||
| 44f963f281 | |||
| 01858e1caf | |||
| 2060db8e11 | |||
| 9ded063417 | |||
| d72da2777c | |||
| 89aede80cc | |||
| e0d3cd91c6 | |||
| 1a054ac1f4 | |||
| 3230f4a0ec | |||
| dadca0f91a | |||
| d489b8b3e0 | |||
| bd0992275c | |||
| 3e7597f2bd | |||
| 0e71f6db84 | |||
| f6b9982c23 | |||
| fb113a9479 | |||
| 15791510c8 | |||
| 0f72a8e89d | |||
| 14af87527f | |||
| 83e84865be | |||
| c2a3c5a748 | |||
| 83494cb4f5 | |||
| 0bc19c3fbf | |||
| 571415d1a4 | |||
| 7b2cf8215f | |||
| fee4d3f6ca | |||
| 161cc0cda9 | |||
| 71bff9fcf3 | |||
| 80d14c9b22 | |||
| c5bdf08558 | |||
| 596f160a1e | |||
| d8b6c053a2 | |||
| 4b262cae58 | |||
| 1a5116cba0 | |||
| 01581dd35f | |||
| 7fdd964379 | |||
| 0cfcc97e9d | |||
| 8986be0aab | |||
| f76bbbf5e6 | |||
| fe217da05c | |||
| 80aa7c4019 | |||
| 6f33351eb3 | |||
| 35f13c7327 | |||
| a8b9e01b3e | |||
| 7193e189f3 | |||
| 3f2a806abe | |||
| 5e4907e940 | |||
| bf63c5d1e3 | |||
| 78989e9049 | |||
| 1510bdbcf6 | |||
| 024d688b77 | |||
| ef82a29e23 | |||
| 1f56a20b62 | |||
| 0c2a62f847 | |||
| ea748b50f2 | |||
| 62bfc4dba6 | |||
| ceb2b150ff | |||
| dc015c380a | |||
| c9e0f0bf20 | |||
| bd6d4d0553 | |||
| f0273f00e1 | |||
| 962cdbbebd | |||
| 2c51e3a327 | |||
| 8e311cc45c | |||
| c441bea4d1 | |||
| ad30668eb6 | |||
| 62f4801523 | |||
| ec1408346e | |||
| 0e0a703496 | |||
| 54b693d5b1 | |||
| 1262277714 | |||
| 3a67fc6c5a | |||
| 26abbe8e5b | |||
| 5d0914daea | |||
| 7541a492b7 | |||
| 3a071b8db9 | |||
| 9342b4b951 | |||
| 4682e0ac7c | |||
| 7cfebffbb8 | |||
| 693fe912f2 | |||
| bc3a8e0ca2 | |||
| e38334cfd2 | |||
| 92cab33b73 | |||
| 3f467613fc | |||
| 205d33a813 | |||
| da326baa5e | |||
| d9198b5646 | |||
| 60001a62c4 | |||
| ee7d5e7206 | |||
| 2726fb3d5d | |||
| d7aa4076c9 | |||
| 122ce41020 | |||
| e7afee1176 | |||
| 88730906ec | |||
| a15080a1d7 | |||
| 35431bce0d | |||
| 7b7576ad55 | |||
| 162faee4f2 | |||
| b7ff98d7ff | |||
| a63e15081f | |||
| 0724640bbb | |||
| cb70e12827 | |||
| 067b956b2c | |||
| e7762b731c | |||
| f6c8390b0b | |||
| 4fd57929df | |||
| 517cdb2ca4 |
8
.github/workflows/style.yml
vendored
8
.github/workflows/style.yml
vendored
@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v44
|
||||
uses: tj-actions/changed-files@v45
|
||||
with:
|
||||
files: api/**
|
||||
|
||||
@ -66,7 +66,7 @@ jobs:
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v44
|
||||
uses: tj-actions/changed-files@v45
|
||||
with:
|
||||
files: web/**
|
||||
|
||||
@ -97,7 +97,7 @@ jobs:
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v44
|
||||
uses: tj-actions/changed-files@v45
|
||||
with:
|
||||
files: |
|
||||
**.sh
|
||||
@ -107,7 +107,7 @@ jobs:
|
||||
dev/**
|
||||
|
||||
- name: Super-linter
|
||||
uses: super-linter/super-linter/slim@v6
|
||||
uses: super-linter/super-linter/slim@v7
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
|
||||
54
.github/workflows/translate-i18n-base-on-english.yml
vendored
Normal file
54
.github/workflows/translate-i18n-base-on-english.yml
vendored
Normal file
@ -0,0 +1,54 @@
|
||||
name: Check i18n Files and Create PR
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed]
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
check-and-update:
|
||||
if: github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: web
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2 # last 2 commits
|
||||
|
||||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
run: |
|
||||
recent_commit_sha=$(git rev-parse HEAD)
|
||||
second_recent_commit_sha=$(git rev-parse HEAD~1)
|
||||
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
|
||||
echo "Changed files: $changed_files"
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "FILES_CHANGED=true" >> $GITHUB_ENV
|
||||
else
|
||||
echo "FILES_CHANGED=false" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Set up Node.js
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: actions/setup-node@v2
|
||||
with:
|
||||
node-version: 'lts/*'
|
||||
|
||||
- name: Install dependencies
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: yarn install --frozen-lockfile
|
||||
|
||||
- name: Run npm script
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: npm run auto-gen-i18n
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
commit-message: Update i18n files based on en-US changes
|
||||
title: 'chore: translate i18n files'
|
||||
body: This PR was automatically created to update i18n files based on changes in en-US locale.
|
||||
branch: chore/automated-i18n-updates
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@ -153,6 +153,9 @@ docker-legacy/volumes/etcd/*
|
||||
docker-legacy/volumes/minio/*
|
||||
docker-legacy/volumes/milvus/*
|
||||
docker-legacy/volumes/chroma/*
|
||||
docker-legacy/volumes/opensearch/data/*
|
||||
docker-legacy/volumes/pgvectors/data/*
|
||||
docker-legacy/volumes/pgvector/data/*
|
||||
|
||||
docker/volumes/app/storage/*
|
||||
docker/volumes/certbot/*
|
||||
@ -164,6 +167,12 @@ docker/volumes/etcd/*
|
||||
docker/volumes/minio/*
|
||||
docker/volumes/milvus/*
|
||||
docker/volumes/chroma/*
|
||||
docker/volumes/opensearch/data/*
|
||||
docker/volumes/myscale/data/*
|
||||
docker/volumes/myscale/log/*
|
||||
docker/volumes/unstructured/*
|
||||
docker/volumes/pgvector/data/*
|
||||
docker/volumes/pgvecto_rs/data/*
|
||||
|
||||
docker/nginx/conf.d/default.conf
|
||||
docker/middleware.env
|
||||
|
||||
@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
|
||||
|
||||
## Before you jump in
|
||||
|
||||
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
|
||||
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
|
||||
|
||||
### Feature requests:
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
|
||||
## 在开始之前
|
||||
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
|
||||
|
||||
### 功能请求:
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは
|
||||
|
||||
## 飛び込む前に
|
||||
|
||||
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
|
||||
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
|
||||
|
||||
### 機能リクエスト
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [
|
||||
|
||||
## Trước khi bắt đầu
|
||||
|
||||
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
|
||||
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
|
||||
|
||||
### Yêu cầu tính năng:
|
||||
|
||||
|
||||
2
LICENSE
2
LICENSE
@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con
|
||||
|
||||
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
|
||||
|
||||
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
|
||||
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
|
||||
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
|
||||
|
||||
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.
|
||||
|
||||
@ -39,7 +39,7 @@ DB_DATABASE=dify
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
# storage type: local, s3, azure-blob, google-storage
|
||||
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos
|
||||
STORAGE_TYPE=local
|
||||
STORAGE_LOCAL_PATH=storage
|
||||
S3_USE_AWS_MANAGED_IAM=false
|
||||
@ -60,7 +60,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
|
||||
ALIYUN_OSS_ENDPOINT=your-endpoint
|
||||
ALIYUN_OSS_AUTH_VERSION=v1
|
||||
ALIYUN_OSS_REGION=your-region
|
||||
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
# Google Storage configuration
|
||||
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
|
||||
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
|
||||
@ -72,6 +73,12 @@ TENCENT_COS_SECRET_ID=your-secret-id
|
||||
TENCENT_COS_REGION=your-region
|
||||
TENCENT_COS_SCHEME=your-scheme
|
||||
|
||||
# Huawei OBS Storage Configuration
|
||||
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
|
||||
HUAWEI_OBS_SECRET_KEY=your-secret-key
|
||||
HUAWEI_OBS_ACCESS_KEY=your-access-key
|
||||
HUAWEI_OBS_SERVER=your-server-url
|
||||
|
||||
# OCI Storage configuration
|
||||
OCI_ENDPOINT=your-endpoint
|
||||
OCI_BUCKET_NAME=your-bucket-name
|
||||
@ -79,6 +86,13 @@ OCI_ACCESS_KEY=your-access-key
|
||||
OCI_SECRET_KEY=your-secret-key
|
||||
OCI_REGION=your-region
|
||||
|
||||
# Volcengine tos Storage configuration
|
||||
VOLCENGINE_TOS_ENDPOINT=your-endpoint
|
||||
VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
|
||||
VOLCENGINE_TOS_ACCESS_KEY=your-access-key
|
||||
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
|
||||
VOLCENGINE_TOS_REGION=your-region
|
||||
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
@ -100,11 +114,10 @@ QDRANT_GRPC_ENABLED=false
|
||||
QDRANT_GRPC_PORT=6334
|
||||
|
||||
# Milvus configuration
|
||||
MILVUS_HOST=127.0.0.1
|
||||
MILVUS_PORT=19530
|
||||
MILVUS_URI=http://127.0.0.1:19530
|
||||
MILVUS_TOKEN=
|
||||
MILVUS_USER=root
|
||||
MILVUS_PASSWORD=Milvus
|
||||
MILVUS_SECURE=false
|
||||
|
||||
# MyScale configuration
|
||||
MYSCALE_HOST=127.0.0.1
|
||||
|
||||
@ -55,7 +55,7 @@ RUN apt-get update \
|
||||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-2 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
|
||||
&& apt-get autoremove -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@ -164,7 +164,7 @@ def initialize_extensions(app):
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint not in ["console", "inner_api"]:
|
||||
if request.blueprint not in {"console", "inner_api"}:
|
||||
return None
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
|
||||
@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm):
|
||||
)
|
||||
@click.confirmation_option(
|
||||
prompt=click.style(
|
||||
"Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
|
||||
"Are you sure you want to reset encrypt key pair? this operation cannot be rolled back!", fg="red"
|
||||
)
|
||||
)
|
||||
def reset_encrypt_key_pair():
|
||||
@ -131,7 +131,7 @@ def reset_encrypt_key_pair():
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
"Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
|
||||
"Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
|
||||
@click.command("vdb-migrate", help="migrate vector db.")
|
||||
@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
|
||||
def vdb_migrate(scope: str):
|
||||
if scope in ["knowledge", "all"]:
|
||||
if scope in {"knowledge", "all"}:
|
||||
migrate_knowledge_vector_database()
|
||||
if scope in ["annotation", "all"]:
|
||||
if scope in {"annotation", "all"}:
|
||||
migrate_annotation_vector_database()
|
||||
|
||||
|
||||
@ -275,8 +275,7 @@ def migrate_knowledge_vector_database():
|
||||
for dataset in datasets:
|
||||
total_count = total_count + 1
|
||||
click.echo(
|
||||
f"Processing the {total_count} dataset {dataset.id}. "
|
||||
+ f"{create_count} created, {skipped_count} skipped."
|
||||
f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
|
||||
)
|
||||
try:
|
||||
click.echo("Create dataset vdb index: {}".format(dataset.id))
|
||||
@ -411,7 +410,8 @@ def migrate_knowledge_vector_database():
|
||||
try:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
|
||||
f"Start to created vector index with {len(documents)} documents of {segments_count}"
|
||||
f" segments for dataset {dataset.id}.",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
@ -559,8 +559,9 @@ def add_qdrant_doc_id_index(field: str):
|
||||
|
||||
@click.command("create-tenant", help="Create account and tenant.")
|
||||
@click.option("--email", prompt=True, help="The email address of the tenant account.")
|
||||
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
|
||||
@click.option("--language", prompt=True, help="Account language, default: en-US.")
|
||||
def create_tenant(email: str, language: Optional[str] = None):
|
||||
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
|
||||
"""
|
||||
Create tenant account
|
||||
"""
|
||||
@ -580,17 +581,19 @@ def create_tenant(email: str, language: Optional[str] = None):
|
||||
if language not in languages:
|
||||
language = "en-US"
|
||||
|
||||
name = name.strip()
|
||||
|
||||
# generate random password
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
|
||||
# register account
|
||||
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
"Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
|
||||
"Congratulations! Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from pydantic import AliasChoices, Field, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.feature.hosted_service import HostedServiceConfig
|
||||
@ -45,8 +45,8 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
Code Execution Sandbox configs
|
||||
"""
|
||||
|
||||
CODE_EXECUTION_ENDPOINT: str = Field(
|
||||
description="endpoint URL of code execution servcie",
|
||||
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
|
||||
description="endpoint URL of code execution service",
|
||||
default="http://sandbox:8194",
|
||||
)
|
||||
|
||||
@ -55,6 +55,21 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
default="dify-sandbox",
|
||||
)
|
||||
|
||||
CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
|
||||
description="connect timeout in seconds for code execution request",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
|
||||
description="read timeout in seconds for code execution request",
|
||||
default=60.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
|
||||
description="write timeout in seconds for code execution request",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_MAX_NUMBER: PositiveInt = Field(
|
||||
description="max depth for code execution",
|
||||
default=9223372036854775807,
|
||||
@ -114,12 +129,12 @@ class EndpointConfig(BaseSettings):
|
||||
)
|
||||
|
||||
SERVICE_API_URL: str = Field(
|
||||
description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
|
||||
description="Service API Url prefix. used to display Service API Base Url to the front-end.",
|
||||
default="",
|
||||
)
|
||||
|
||||
APP_WEB_URL: str = Field(
|
||||
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
|
||||
description="WebApp Url prefix. used to display WebAPP API Base Url to the front-end.",
|
||||
default="",
|
||||
)
|
||||
|
||||
@ -202,20 +217,17 @@ class HttpConfig(BaseSettings):
|
||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field(
|
||||
description="",
|
||||
default=300,
|
||||
)
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
|
||||
] = 10
|
||||
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field(
|
||||
description="",
|
||||
default=600,
|
||||
)
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
|
||||
] = 60
|
||||
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field(
|
||||
description="",
|
||||
default=600,
|
||||
)
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
|
||||
] = 20
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
||||
description="",
|
||||
@ -260,7 +272,7 @@ class LoggingConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
LOG_LEVEL: str = Field(
|
||||
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
|
||||
description="Log output level, default to INFO. It is recommended to set it to ERROR for production.",
|
||||
default="INFO",
|
||||
)
|
||||
|
||||
@ -403,7 +415,7 @@ class MailConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
MAIL_TYPE: Optional[str] = Field(
|
||||
description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
|
||||
description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.middleware.cache.redis_config import RedisConfig
|
||||
@ -9,10 +9,13 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag
|
||||
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
|
||||
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
|
||||
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
|
||||
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
|
||||
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
|
||||
from configs.middleware.vdb.chroma_config import ChromaConfig
|
||||
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from configs.middleware.vdb.milvus_config import MilvusConfig
|
||||
from configs.middleware.vdb.myscale_config import MyScaleConfig
|
||||
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
|
||||
@ -156,6 +159,21 @@ class CeleryConfig(DatabaseConfig):
|
||||
default=None,
|
||||
)
|
||||
|
||||
CELERY_USE_SENTINEL: Optional[bool] = Field(
|
||||
description="Whether to use Redis Sentinel mode",
|
||||
default=False,
|
||||
)
|
||||
|
||||
CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
|
||||
description="Redis Sentinel master name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
|
||||
description="Redis Sentinel socket timeout",
|
||||
default=0.1,
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def CELERY_RESULT_BACKEND(self) -> str | None:
|
||||
@ -183,6 +201,8 @@ class MiddlewareConfig(
|
||||
AzureBlobStorageConfig,
|
||||
GoogleCloudStorageConfig,
|
||||
TencentCloudCOSStorageConfig,
|
||||
HuaweiCloudOBSStorageConfig,
|
||||
VolcengineTOSStorageConfig,
|
||||
S3StorageConfig,
|
||||
OCIStorageConfig,
|
||||
# configs of vdb and vdb providers
|
||||
@ -200,5 +220,6 @@ class MiddlewareConfig(
|
||||
TencentVectorDBConfig,
|
||||
TiDBVectorConfig,
|
||||
WeaviateConfig,
|
||||
ElasticsearchConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
32
api/configs/middleware/cache/redis_config.py
vendored
32
api/configs/middleware/cache/redis_config.py
vendored
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -38,3 +38,33 @@ class RedisConfig(BaseSettings):
|
||||
description="whether to use SSL for Redis connection",
|
||||
default=False,
|
||||
)
|
||||
|
||||
REDIS_USE_SENTINEL: Optional[bool] = Field(
|
||||
description="Whether to use Redis Sentinel mode",
|
||||
default=False,
|
||||
)
|
||||
|
||||
REDIS_SENTINELS: Optional[str] = Field(
|
||||
description="Redis Sentinel nodes",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
|
||||
description="Redis Sentinel service name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_USERNAME: Optional[str] = Field(
|
||||
description="Redis Sentinel username",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
|
||||
description="Redis Sentinel password",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
|
||||
description="Redis Sentinel socket timeout",
|
||||
default=0.1,
|
||||
)
|
||||
|
||||
@ -38,3 +38,8 @@ class AliyunOSSStorageConfig(BaseSettings):
|
||||
description="Aliyun OSS authentication version",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_PATH: Optional[str] = Field(
|
||||
description="Aliyun OSS path",
|
||||
default=None,
|
||||
)
|
||||
|
||||
29
api/configs/middleware/storage/huawei_obs_storage_config.py
Normal file
29
api/configs/middleware/storage/huawei_obs_storage_config.py
Normal file
@ -0,0 +1,29 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class HuaweiCloudOBSStorageConfig(BaseModel):
|
||||
"""
|
||||
Huawei Cloud OBS storage configs
|
||||
"""
|
||||
|
||||
HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS bucket name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS Access key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS Secret key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_SERVER: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS server URL",
|
||||
default=None,
|
||||
)
|
||||
@ -0,0 +1,34 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VolcengineTOSStorageConfig(BaseModel):
|
||||
"""
|
||||
Volcengine tos storage configs
|
||||
"""
|
||||
|
||||
VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Volcengine TOS Bucket Name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
|
||||
description="Volcengine TOS Access Key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
|
||||
description="Volcengine TOS Secret Key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
|
||||
description="Volcengine TOS Endpoint URL",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_REGION: Optional[str] = Field(
|
||||
description="Volcengine TOS Region",
|
||||
default=None,
|
||||
)
|
||||
30
api/configs/middleware/vdb/elasticsearch_config.py
Normal file
30
api/configs/middleware/vdb/elasticsearch_config.py
Normal file
@ -0,0 +1,30 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ElasticsearchConfig(BaseSettings):
|
||||
"""
|
||||
Elasticsearch configs
|
||||
"""
|
||||
|
||||
ELASTICSEARCH_HOST: Optional[str] = Field(
|
||||
description="Elasticsearch host",
|
||||
default="127.0.0.1",
|
||||
)
|
||||
|
||||
ELASTICSEARCH_PORT: PositiveInt = Field(
|
||||
description="Elasticsearch port",
|
||||
default=9200,
|
||||
)
|
||||
|
||||
ELASTICSEARCH_USERNAME: Optional[str] = Field(
|
||||
description="Elasticsearch username",
|
||||
default="elastic",
|
||||
)
|
||||
|
||||
ELASTICSEARCH_PASSWORD: Optional[str] = Field(
|
||||
description="Elasticsearch password",
|
||||
default="elastic",
|
||||
)
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -9,14 +9,14 @@ class MilvusConfig(BaseSettings):
|
||||
Milvus configs
|
||||
"""
|
||||
|
||||
MILVUS_HOST: Optional[str] = Field(
|
||||
description="Milvus host",
|
||||
default=None,
|
||||
MILVUS_URI: Optional[str] = Field(
|
||||
description="Milvus uri",
|
||||
default="http://127.0.0.1:19530",
|
||||
)
|
||||
|
||||
MILVUS_PORT: PositiveInt = Field(
|
||||
description="Milvus RestFul API port",
|
||||
default=9091,
|
||||
MILVUS_TOKEN: Optional[str] = Field(
|
||||
description="Milvus token",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_USER: Optional[str] = Field(
|
||||
@ -29,11 +29,6 @@ class MilvusConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_SECURE: bool = Field(
|
||||
description="whether to use SSL connection for Milvus",
|
||||
default=False,
|
||||
)
|
||||
|
||||
MILVUS_DATABASE: str = Field(
|
||||
description="Milvus database, default to `default`",
|
||||
default="default",
|
||||
|
||||
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.7.2",
|
||||
default="0.8.2",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -37,7 +37,7 @@ from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_p
|
||||
from .billing import billing
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, external, file, hit_testing, website
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
|
||||
@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
desc = args["desc"] if args["desc"] else ""
|
||||
copy_right = args["copyright"] if args["copyright"] else ""
|
||||
privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
|
||||
custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
|
||||
desc = args["desc"] or ""
|
||||
copy_right = args["copyright"] or ""
|
||||
privacy_policy = args["privacy_policy"] or ""
|
||||
custom_disclaimer = args["custom_disclaimer"] or ""
|
||||
else:
|
||||
desc = site.description if site.description else args["desc"] if args["desc"] else ""
|
||||
copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
|
||||
privacy_policy = (
|
||||
site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
|
||||
)
|
||||
custom_disclaimer = (
|
||||
site.custom_disclaimer
|
||||
if site.custom_disclaimer
|
||||
else args["custom_disclaimer"]
|
||||
if args["custom_disclaimer"]
|
||||
else ""
|
||||
)
|
||||
desc = site.description or args["desc"] or ""
|
||||
copy_right = site.copyright or args["copyright"] or ""
|
||||
privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
|
||||
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
|
||||
|
||||
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
|
||||
|
||||
|
||||
@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource):
|
||||
def post(self, resource_id):
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_admin_or_owner:
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
|
||||
@ -174,6 +174,7 @@ class AppApi(Resource):
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("max_active_requests", type=int, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
@ -94,19 +94,15 @@ class ChatMessageTextApi(Resource):
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = (
|
||||
args.get("voice")
|
||||
if args.get("voice")
|
||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
)
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
|
||||
|
||||
@ -20,7 +20,7 @@ from fields.conversation_fields import (
|
||||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
|
||||
|
||||
@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
@ -201,7 +201,11 @@ class ChatConversationApi(Resource):
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
@ -210,7 +214,11 @@ class ChatConversationApi(Resource):
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||
|
||||
if args["annotation_status"] == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join(
|
||||
|
||||
@ -32,6 +32,8 @@ class ModelConfigResource(Resource):
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
updated_by=current_user.id,
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
@ -32,6 +34,7 @@ def parse_app_site_args():
|
||||
)
|
||||
parser.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -66,11 +69,14 @@ class AppSite(Resource):
|
||||
"customize_token_strategy",
|
||||
"prompt_public",
|
||||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
]:
|
||||
value = args.get(attr_name)
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
@ -93,6 +99,8 @@ class AppSiteAccessTokenReset(Resource):
|
||||
raise NotFound
|
||||
|
||||
site.code = Site.generate_code(16)
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
|
||||
@ -11,12 +11,12 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class DailyConversationStatistic(Resource):
|
||||
class DailyMessageStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -25,14 +25,17 @@ class DailyConversationStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(*) AS message_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
@ -45,7 +48,7 @@ class DailyConversationStatistic(Resource):
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
@ -55,10 +58,67 @@ class DailyConversationStatistic(Resource):
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(db.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "message_count": i.message_count})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
class DailyConversationStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT messages.conversation_id) AS conversation_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
@ -79,14 +139,17 @@ class DailyTerminalsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
@ -99,7 +162,7 @@ class DailyTerminalsStatistic(Resource):
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
@ -109,10 +172,10 @@ class DailyTerminalsStatistic(Resource):
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
@ -133,16 +196,18 @@ class DailyTokenCostStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
(sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
|
||||
sum(total_price) as total_price
|
||||
FROM messages where app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
|
||||
SUM(total_price) AS total_price
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
@ -155,7 +220,7 @@ class DailyTokenCostStatistic(Resource):
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
@ -165,10 +230,10 @@ class DailyTokenCostStatistic(Resource):
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
@ -191,16 +256,26 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(subquery.message_count) AS interactions
|
||||
FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
||||
FROM conversations c
|
||||
JOIN messages m ON c.id = m.conversation_id
|
||||
WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(subquery.message_count) AS interactions
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
m.conversation_id,
|
||||
COUNT(m.id) AS message_count
|
||||
FROM
|
||||
conversations c
|
||||
JOIN
|
||||
messages m
|
||||
ON c.id = m.conversation_id
|
||||
WHERE
|
||||
c.override_model_configs IS NULL
|
||||
AND c.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
@ -213,7 +288,7 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and c.created_at >= :start"
|
||||
sql_query += " AND c.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
@ -223,14 +298,19 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and c.created_at < :end"
|
||||
sql_query += " AND c.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += """
|
||||
GROUP BY m.conversation_id) subquery
|
||||
LEFT JOIN conversations c on c.id=subquery.conversation_id
|
||||
GROUP BY date
|
||||
ORDER BY date"""
|
||||
GROUP BY m.conversation_id
|
||||
) subquery
|
||||
LEFT JOIN
|
||||
conversations c
|
||||
ON c.id = subquery.conversation_id
|
||||
GROUP BY
|
||||
date
|
||||
ORDER BY
|
||||
date"""
|
||||
|
||||
response_data = []
|
||||
|
||||
@ -253,17 +333,21 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
|
||||
FROM messages m
|
||||
LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
|
||||
WHERE m.app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(m.id) AS message_count,
|
||||
COUNT(mf.id) AS feedback_count
|
||||
FROM
|
||||
messages m
|
||||
LEFT JOIN
|
||||
message_feedbacks mf
|
||||
ON mf.message_id=m.id AND mf.rating='like'
|
||||
WHERE
|
||||
m.app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
@ -276,7 +360,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and m.created_at >= :start"
|
||||
sql_query += " AND m.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
@ -286,10 +370,10 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and m.created_at < :end"
|
||||
sql_query += " AND m.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
@ -315,16 +399,17 @@ class AverageResponseTimeStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(provider_response_latency) as latency
|
||||
FROM messages
|
||||
WHERE app_id = :app_id
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(provider_response_latency) AS latency
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
@ -337,7 +422,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
@ -347,10 +432,10 @@ class AverageResponseTimeStatistic(Resource):
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
@ -371,17 +456,20 @@ class TokensPerSecondStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
CASE
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
CASE
|
||||
WHEN SUM(provider_response_latency) = 0 THEN 0
|
||||
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
||||
END as tokens_per_second
|
||||
FROM messages
|
||||
WHERE app_id = :app_id"""
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
|
||||
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
@ -394,7 +482,7 @@ WHERE app_id = :app_id"""
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
@ -404,10 +492,10 @@ WHERE app_id = :app_id"""
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
@ -419,6 +507,7 @@ WHERE app_id = :app_id"""
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
api.add_resource(DailyMessageStatistic, "/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users")
|
||||
api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs")
|
||||
|
||||
@ -465,6 +465,6 @@ api.add_resource(
|
||||
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
|
||||
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
|
||||
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
|
||||
)
|
||||
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
|
||||
|
||||
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRunTriggeredFrom
|
||||
@ -26,16 +26,18 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
|
||||
FROM workflow_runs
|
||||
WHERE app_id = :app_id
|
||||
AND triggered_from = :triggered_from
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(id) AS runs
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
@ -52,7 +54,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
@ -62,10 +64,10 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
@ -86,16 +88,18 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
|
||||
FROM workflow_runs
|
||||
WHERE app_id = :app_id
|
||||
AND triggered_from = :triggered_from
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
@ -112,7 +116,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
@ -122,10 +126,10 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
@ -146,18 +150,18 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT
|
||||
date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
SUM(workflow_runs.total_tokens) as token_count
|
||||
FROM workflow_runs
|
||||
WHERE app_id = :app_id
|
||||
AND triggered_from = :triggered_from
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
SUM(workflow_runs.total_tokens) AS token_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
@ -174,7 +178,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at >= :start"
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
@ -184,10 +188,10 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " and created_at < :end"
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date order by date"
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
@ -213,27 +217,31 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
SELECT
|
||||
AVG(sub.interactions) as interactions,
|
||||
sub.date
|
||||
FROM
|
||||
(SELECT
|
||||
date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
c.created_by,
|
||||
COUNT(c.id) AS interactions
|
||||
FROM workflow_runs c
|
||||
WHERE c.app_id = :app_id
|
||||
AND c.triggered_from = :triggered_from
|
||||
{{start}}
|
||||
{{end}}
|
||||
GROUP BY date, c.created_by) sub
|
||||
GROUP BY sub.date
|
||||
"""
|
||||
sql_query = """SELECT
|
||||
AVG(sub.interactions) AS interactions,
|
||||
sub.date
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
c.created_by,
|
||||
COUNT(c.id) AS interactions
|
||||
FROM
|
||||
workflow_runs c
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND c.triggered_from = :triggered_from
|
||||
{{start}}
|
||||
{{end}}
|
||||
GROUP BY
|
||||
date, c.created_by
|
||||
) sub
|
||||
GROUP BY
|
||||
sub.date"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
@ -262,7 +270,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
|
||||
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{end}}", "")
|
||||
|
||||
@ -8,7 +8,7 @@ from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, str_len, timezone
|
||||
from libs.helper import StrLen, email, timezone
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
@ -37,7 +37,7 @@ class ActivateApi(Resource):
|
||||
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
|
||||
@ -71,7 +71,7 @@ class OAuthCallback(Resource):
|
||||
|
||||
account = _generate_account(provider, user_info)
|
||||
# Check account status
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
|
||||
return {"error": "Account is banned or closed."}, 403
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
||||
|
||||
if not account:
|
||||
# Create account
|
||||
account_name = user_info.name if user_info.name else "Dify"
|
||||
account_name = user_info.name or "Dify"
|
||||
account = RegisterService.register(
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
)
|
||||
|
||||
@ -18,7 +18,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
@ -110,6 +110,26 @@ class DatasetListApi(Resource):
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_api_template_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
nullable=True,
|
||||
choices=Dataset.PROVIDER_LIST,
|
||||
required=False,
|
||||
default="vendor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
@ -122,6 +142,10 @@ class DatasetListApi(Resource):
|
||||
name=args["name"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
provider=args["provider"],
|
||||
external_api_template_id=args["external_api_template_id"],
|
||||
external_knowledge_id=args["external_knowledge_id"],
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
@ -398,7 +422,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -549,12 +573,7 @@ class DatasetApiBaseUrlApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
return {
|
||||
"api_base_url": (
|
||||
dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
|
||||
)
|
||||
+ "/v1"
|
||||
}
|
||||
return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
|
||||
|
||||
|
||||
class DatasetRetrievalSettingApi(Resource):
|
||||
|
||||
@ -302,6 +302,8 @@ class DatasetInitApi(Resource):
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
@ -309,6 +311,8 @@ class DatasetInitApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
if args["indexing_technique"] == "high_quality":
|
||||
if args["embedding_model"] is None or args["embedding_model_provider"] is None:
|
||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_default_model_instance(
|
||||
@ -350,7 +354,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
if document.indexing_status in ["completed", "error"]:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
|
||||
data_process_rule = document.dataset_process_rule
|
||||
@ -417,7 +421,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
info_list = []
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in ["completed", "error"]:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
# format document files info
|
||||
@ -599,6 +603,7 @@ class DocumentDetailApi(DocumentResource):
|
||||
"hit_count": document.hit_count,
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
}
|
||||
else:
|
||||
process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
@ -631,6 +636,7 @@ class DocumentDetailApi(DocumentResource):
|
||||
"hit_count": document.hit_count,
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
}
|
||||
|
||||
return response, 200
|
||||
@ -659,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
|
||||
db.session.commit()
|
||||
|
||||
elif action == "resume":
|
||||
if document.indexing_status not in ["paused", "error"]:
|
||||
if document.indexing_status not in {"paused", "error"}:
|
||||
raise InvalidActionError("Document not in paused or error state.")
|
||||
|
||||
document.paused_by = None
|
||||
|
||||
254
api/controllers/console/datasets/external.py
Normal file
254
api/controllers/console/datasets/external.py
Normal file
@ -0,0 +1,254 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class ExternalApiTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
|
||||
api_templates, total = ExternalDatasetService.get_external_api_templates(
|
||||
page, limit, current_user.current_tenant_id, search
|
||||
)
|
||||
response = {
|
||||
"data": [item.to_dict() for item in api_templates],
|
||||
"has_more": len(api_templates) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Description is required. Description must be between 1 to 400 characters.",
|
||||
type=_validate_description_length,
|
||||
)
|
||||
parser.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
api_template = ExternalDatasetService.create_api_template(
|
||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return api_template.to_dict(), 201
|
||||
|
||||
|
||||
class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, api_template_id):
|
||||
api_template_id = str(api_template_id)
|
||||
api_template = ExternalDatasetService.get_api_template(api_template_id)
|
||||
if api_template is None:
|
||||
raise NotFound("API template not found.")
|
||||
|
||||
return api_template.to_dict(), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, api_template_id):
|
||||
api_template_id = str(api_template_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="description is required. Description must be between 1 to 400 characters.",
|
||||
type=_validate_description_length,
|
||||
)
|
||||
parser.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
|
||||
api_template = ExternalDatasetService.update_api_template(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
api_template_id=api_template_id,
|
||||
args=args,
|
||||
)
|
||||
|
||||
return api_template.to_dict(), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, api_template_id):
|
||||
api_template_id = str(api_template_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor or current_user.is_dataset_operator:
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_api_template(current_user.current_tenant_id, api_template_id)
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ExternalApiUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, api_template_id):
|
||||
api_template_id = str(api_template_id)
|
||||
|
||||
external_api_template_is_using = ExternalDatasetService.external_api_template_use_check(api_template_id)
|
||||
return {"is_using": external_api_template_is_using}, 200
|
||||
|
||||
|
||||
class ExternalDatasetInitApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("api_template_id", type=str, required=True, nullable=True, location="json")
|
||||
# parser.add_argument('name', nullable=False, required=True,
|
||||
# help='name is required. Name must be between 1 to 100 characters.',
|
||||
# type=_validate_name)
|
||||
# parser.add_argument('description', type=str, required=True, nullable=True, location='json')
|
||||
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_parameter", type=dict, required=True, nullable=True, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# validate args
|
||||
ExternalDatasetService.document_create_args_validate(
|
||||
current_user.current_tenant_id, args["api_template_id"], args["process_parameter"]
|
||||
)
|
||||
|
||||
try:
|
||||
dataset, documents, batch = ExternalDatasetService.init_external_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
args=args,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
response = {"dataset": dataset, "documents": documents, "batch": batch}
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ExternalDatasetCreateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("external_api_template_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument("description", type=str, required=True, nullable=True, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
dataset = ExternalDatasetService.create_external_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
args=args,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-api-template")
|
||||
api.add_resource(ExternalApiTemplateApi, "/datasets/external-api-template/<uuid:api_template_id>")
|
||||
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-api-template/<uuid:api_template_id>/use-check")
|
||||
@ -39,7 +39,7 @@ class FileApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(file_fields)
|
||||
@cloud_edition_billing_resource_check(resource="documents")
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
def post(self):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
@ -47,6 +47,7 @@ class HitTestingApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrival_model", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
@ -57,6 +58,7 @@ class HitTestingApi(Resource):
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrival_model"],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
|
||||
49
api/controllers/console/datasets/test_external.py
Normal file
49
api/controllers/console/datasets/test_external.py
Normal file
@ -0,0 +1,49 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
class TestExternalApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"top_k",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=int,
|
||||
)
|
||||
parser.add_argument(
|
||||
"score_threshold",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=float,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
result = ExternalDatasetService.test_external_knowledge_retrival(
|
||||
args["top_k"], args["score_threshold"]
|
||||
)
|
||||
response = {
|
||||
"data": [item.to_dict() for item in api_templates],
|
||||
"has_more": len(api_templates) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
|
||||
|
||||
|
||||
api.add_resource(TestExternalApi, "/dify/external-knowledge/retrival-documents")
|
||||
@ -18,9 +18,7 @@ class NotSetupError(BaseHTTPException):
|
||||
|
||||
class NotInitValidateError(BaseHTTPException):
|
||||
error_code = "not_init_validated"
|
||||
description = (
|
||||
"Init validation has not been completed yet. " "Please proceed with the init validation process first."
|
||||
)
|
||||
description = "Init validation has not been completed yet. Please proceed with the init validation process first."
|
||||
code = 401
|
||||
|
||||
|
||||
|
||||
@ -81,19 +81,15 @@ class ChatTextApi(InstalledAppResource):
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = (
|
||||
args.get("voice")
|
||||
if args.get("voice")
|
||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
)
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
|
||||
|
||||
@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
|
||||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
|
||||
def delete(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@ -31,10 +31,11 @@ class InstalledAppsListApi(Resource):
|
||||
"app_owner_tenant_id": installed_app.app_owner_tenant_id,
|
||||
"is_pinned": installed_app.is_pinned,
|
||||
"last_used_at": installed_app.last_used_at,
|
||||
"editable": current_user.role in ["owner", "admin"],
|
||||
"editable": current_user.role in {"owner", "admin"},
|
||||
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
|
||||
}
|
||||
for installed_app in installed_apps
|
||||
if installed_app.app is not None
|
||||
]
|
||||
installed_apps.sort(
|
||||
key=lambda app: (
|
||||
|
||||
@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
|
||||
app_model = installed_app.app
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
def get(self, installed_app, message_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
|
||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
@ -4,7 +4,7 @@ from flask import session
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import str_len
|
||||
from libs.helper import StrLen
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
|
||||
@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
|
||||
raise AlreadySetupError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("password", type=str_len(30), required=True, location="json")
|
||||
parser.add_argument("password", type=StrLen(30), required=True, location="json")
|
||||
input_password = parser.parse_args()["password"]
|
||||
|
||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||
|
||||
@ -4,7 +4,7 @@ from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import email, get_remote_ip, str_len
|
||||
from libs.helper import StrLen, email, get_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup
|
||||
from services.account_service import RegisterService, TenantService
|
||||
@ -40,7 +40,7 @@ class SetupApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("name", type=str_len(30), required=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
if not name or len(name) < 1 or len(name) > 50:
|
||||
raise ValueError("Name must be between 1 to 50 characters.")
|
||||
return name
|
||||
|
||||
|
||||
@ -218,7 +218,7 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
|
||||
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
|
||||
api.add_resource(
|
||||
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>"
|
||||
ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
|
||||
@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):
|
||||
|
||||
return ApiToolManageService.test_api_tool_preview(
|
||||
current_user.current_tenant_id,
|
||||
args["provider_name"] if args["provider_name"] else "",
|
||||
args["provider_name"] or "",
|
||||
args["tool_name"],
|
||||
args["credentials"],
|
||||
args["parameters"],
|
||||
|
||||
@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
|
||||
raise TooManyFilesError()
|
||||
|
||||
extension = file.filename.split(".")[-1]
|
||||
if extension.lower() not in ["svg", "png"]:
|
||||
if extension.lower() not in {"svg", "png"}:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
try:
|
||||
|
||||
@ -46,9 +46,7 @@ def only_edition_self_hosted(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(
|
||||
resource: str, error_msg: str = "You have reached the limit of your subscription."
|
||||
):
|
||||
def cloud_edition_billing_resource_check(resource: str):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
@ -60,22 +58,23 @@ def cloud_edition_billing_resource_check(
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
annotation_quota_limit = features.annotation_quota_limit
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
abort(403, error_msg)
|
||||
abort(403, "The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
abort(403, error_msg)
|
||||
abort(403, "The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
abort(403, error_msg)
|
||||
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
|
||||
# The api of file upload is used in the multiple places,
|
||||
# so we need to check the source of the request from datasets
|
||||
source = request.args.get("source")
|
||||
if source == "datasets":
|
||||
abort(403, error_msg)
|
||||
abort(403, "The number of documents has reached the limit of your subscription.")
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
elif resource == "workspace_custom" and not features.can_replace_logo:
|
||||
abort(403, error_msg)
|
||||
abort(403, "The workspace custom feature has reached the limit of your subscription.")
|
||||
elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
|
||||
abort(403, error_msg)
|
||||
abort(403, "The annotation quota has reached the limit of your subscription.")
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -86,10 +85,7 @@ def cloud_edition_billing_resource_check(
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(
|
||||
resource: str,
|
||||
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
):
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
@ -97,7 +93,10 @@ def cloud_edition_billing_knowledge_limit_check(
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
abort(403, error_msg)
|
||||
abort(
|
||||
403,
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ class AppParameterApi(Resource):
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
@ -79,19 +79,15 @@ class TextApi(Resource):
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = (
|
||||
args.get("voice")
|
||||
if args.get("voice")
|
||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
)
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(
|
||||
|
||||
@ -96,7 +96,7 @@ class ChatApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -144,7 +144,7 @@ class ChatStopApi(Resource):
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, task_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
|
||||
|
||||
@ -18,7 +18,7 @@ class ConversationApi(Resource):
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@ -76,7 +76,7 @@ class MessageListApi(Resource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
|
||||
def get(self, app_model: App, end_user: EndUser, message_id):
|
||||
message_id = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse
|
||||
from flask_restful.inputs import int_range
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.service_api import api
|
||||
@ -22,10 +23,12 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -113,6 +116,30 @@ class WorkflowTaskStopApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@validate_app_token
|
||||
@marshal_with(workflow_app_log_pagination_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
app_model=app_model, args=args
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
|
||||
|
||||
api.add_resource(WorkflowRunApi, "/workflows/run")
|
||||
api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
|
||||
api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
|
||||
api.add_resource(WorkflowAppLogApi, "/workflows/logs")
|
||||
|
||||
@ -82,6 +82,26 @@ class DatasetListApi(DatasetApiResource):
|
||||
required=False,
|
||||
nullable=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_api_template_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="_validate_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="vendor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -91,6 +111,9 @@ class DatasetListApi(DatasetApiResource):
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
permission=args["permission"],
|
||||
provider=args["provider"],
|
||||
external_api_template_id=args["external_api_template_id"],
|
||||
external_knowledge_id=args["external_knowledge_id"],
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
@ -36,6 +36,10 @@ class SegmentApi(DatasetApiResource):
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if document.indexing_status != "completed":
|
||||
raise NotFound("Document is not completed.")
|
||||
if not document.enabled:
|
||||
raise NotFound("Document is disabled.")
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
try:
|
||||
@ -63,7 +67,7 @@ class SegmentApi(DatasetApiResource):
|
||||
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
|
||||
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
||||
else:
|
||||
return {"error": "Segemtns is required"}, 400
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
"""Create single segment."""
|
||||
|
||||
@ -83,9 +83,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(
|
||||
resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription."
|
||||
):
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
@ -98,13 +96,13 @@ def cloud_edition_billing_resource_check(
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
raise Forbidden(error_msg)
|
||||
raise Forbidden("The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
raise Forbidden(error_msg)
|
||||
raise Forbidden("The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
raise Forbidden(error_msg)
|
||||
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
raise Forbidden(error_msg)
|
||||
raise Forbidden("The number of documents has reached the limit of your subscription.")
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -115,11 +113,7 @@ def cloud_edition_billing_resource_check(
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
):
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
@ -128,7 +122,9 @@ def cloud_edition_billing_knowledge_limit_check(
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
raise Forbidden(error_msg)
|
||||
raise Forbidden(
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
|
||||
)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
|
||||
@marshal_with(parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
@ -78,19 +78,15 @@ class TextApi(WebApiResource):
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = (
|
||||
args.get("voice")
|
||||
if args.get("voice")
|
||||
else app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
)
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
|
||||
|
||||
@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
|
||||
class ChatApi(WebApiResource):
|
||||
def post(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
|
||||
class ChatStopApi(WebApiResource):
|
||||
def post(self, app_model, end_user, task_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
|
||||
|
||||
@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
|
||||
@marshal_with(conversation_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
|
||||
class ConversationApi(WebApiResource):
|
||||
def delete(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
|
||||
@marshal_with(simple_conversation_fields)
|
||||
def post(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
|
||||
class ConversationPinApi(WebApiResource):
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
|
||||
class ConversationUnPinApi(WebApiResource):
|
||||
def patch(self, app_model, end_user, c_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model, end_user):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
class MessageSuggestedQuestionApi(WebApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
@ -39,6 +39,7 @@ class AppSiteApi(WebApiResource):
|
||||
"default_language": fields.String,
|
||||
"prompt_public": fields.Boolean,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
}
|
||||
|
||||
app_fields = {
|
||||
|
||||
@ -80,7 +80,8 @@ def _validate_web_sso_token(decoded, system_features, app_code):
|
||||
if not source or source != "sso":
|
||||
raise WebSSOAuthRequiredError()
|
||||
|
||||
# Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
|
||||
# Check if SSO is not enforced for web, and if the token source is SSO,
|
||||
# raise an error and redirect to normal passport login
|
||||
if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
|
||||
source = decoded.get("token_source")
|
||||
if source and source == "sso":
|
||||
|
||||
@ -1 +1 @@
|
||||
import core.moderation.base
|
||||
import core.moderation.base
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
@ -45,22 +46,25 @@ from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgentRunner(AppRunner):
|
||||
def __init__(self, tenant_id: str,
|
||||
application_generate_entity: AgentChatAppGenerateEntity,
|
||||
conversation: Conversation,
|
||||
app_config: AgentChatAppConfig,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: AgentEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance = None
|
||||
) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
application_generate_entity: AgentChatAppGenerateEntity,
|
||||
conversation: Conversation,
|
||||
app_config: AgentChatAppConfig,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: AgentEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance = None,
|
||||
) -> None:
|
||||
"""
|
||||
Agent runner
|
||||
:param tenant_id: tenant id
|
||||
@ -88,9 +92,7 @@ class BaseAgentRunner(AppRunner):
|
||||
self.message = message
|
||||
self.user_id = user_id
|
||||
self.memory = memory
|
||||
self.history_prompt_messages = self.organize_agent_history(
|
||||
prompt_messages=prompt_messages or []
|
||||
)
|
||||
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
|
||||
self.variables_pool = variables_pool
|
||||
self.db_variables_pool = db_variables
|
||||
self.model_instance = model_instance
|
||||
@ -111,12 +113,16 @@ class BaseAgentRunner(AppRunner):
|
||||
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
|
||||
return_resource=app_config.additional_features.show_retrieve_source,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
hit_callback=hit_callback
|
||||
hit_callback=hit_callback,
|
||||
)
|
||||
# get how many agent thoughts have been created
|
||||
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
).count()
|
||||
self.agent_thought_count = (
|
||||
db.session.query(MessageAgentThought)
|
||||
.filter(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
db.session.close()
|
||||
|
||||
# check if model supports stream tool call
|
||||
@ -135,25 +141,26 @@ class BaseAgentRunner(AppRunner):
|
||||
self.query = None
|
||||
self._current_thoughts: list[PromptMessage] = []
|
||||
|
||||
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
|
||||
-> AgentChatAppGenerateEntity:
|
||||
def _repack_app_generate_entity(
|
||||
self, app_generate_entity: AgentChatAppGenerateEntity
|
||||
) -> AgentChatAppGenerateEntity:
|
||||
"""
|
||||
Repack app generate entity
|
||||
"""
|
||||
if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
|
||||
app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
|
||||
app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
|
||||
|
||||
return app_generate_entity
|
||||
|
||||
|
||||
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
|
||||
"""
|
||||
convert tool to prompt message tool
|
||||
convert tool to prompt message tool
|
||||
"""
|
||||
tool_entity = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_config.app_id,
|
||||
agent_tool=tool,
|
||||
invoke_from=self.application_generate_entity.invoke_from
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
@ -164,7 +171,7 @@ class BaseAgentRunner(AppRunner):
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
@ -177,19 +184,19 @@ class BaseAgentRunner(AppRunner):
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options]
|
||||
|
||||
message_tool.parameters['properties'][parameter.name] = {
|
||||
message_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters['properties'][parameter.name]['enum'] = enum
|
||||
message_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
if parameter.required:
|
||||
message_tool.parameters['required'].append(parameter.name)
|
||||
message_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
return message_tool, tool_entity
|
||||
|
||||
|
||||
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
|
||||
"""
|
||||
convert dataset retriever tool to prompt message tool
|
||||
@ -201,24 +208,24 @@ class BaseAgentRunner(AppRunner):
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
for parameter in tool.get_runtime_parameters():
|
||||
parameter_type = 'string'
|
||||
|
||||
prompt_tool.parameters['properties'][parameter.name] = {
|
||||
parameter_type = "string"
|
||||
|
||||
prompt_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
|
||||
if parameter.required:
|
||||
if parameter.name not in prompt_tool.parameters['required']:
|
||||
prompt_tool.parameters['required'].append(parameter.name)
|
||||
if parameter.name not in prompt_tool.parameters["required"]:
|
||||
prompt_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
return prompt_tool
|
||||
|
||||
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
|
||||
|
||||
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
|
||||
"""
|
||||
Init tools
|
||||
"""
|
||||
@ -261,51 +268,51 @@ class BaseAgentRunner(AppRunner):
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options]
|
||||
|
||||
prompt_tool.parameters['properties'][parameter.name] = {
|
||||
|
||||
prompt_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
|
||||
if len(enum) > 0:
|
||||
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
|
||||
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
if parameter.required:
|
||||
if parameter.name not in prompt_tool.parameters['required']:
|
||||
prompt_tool.parameters['required'].append(parameter.name)
|
||||
if parameter.name not in prompt_tool.parameters["required"]:
|
||||
prompt_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
return prompt_tool
|
||||
|
||||
def create_agent_thought(self, message_id: str, message: str,
|
||||
tool_name: str, tool_input: str, messages_ids: list[str]
|
||||
) -> MessageAgentThought:
|
||||
|
||||
def create_agent_thought(
|
||||
self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
|
||||
) -> MessageAgentThought:
|
||||
"""
|
||||
Create agent thought
|
||||
"""
|
||||
thought = MessageAgentThought(
|
||||
message_id=message_id,
|
||||
message_chain_id=None,
|
||||
thought='',
|
||||
thought="",
|
||||
tool=tool_name,
|
||||
tool_labels_str='{}',
|
||||
tool_meta_str='{}',
|
||||
tool_labels_str="{}",
|
||||
tool_meta_str="{}",
|
||||
tool_input=tool_input,
|
||||
message=message,
|
||||
message_token=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
message_files=json.dumps(messages_ids) if messages_ids else '',
|
||||
answer='',
|
||||
observation='',
|
||||
message_files=json.dumps(messages_ids) if messages_ids else "",
|
||||
answer="",
|
||||
observation="",
|
||||
answer_token=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
tokens=0,
|
||||
total_price=0,
|
||||
position=self.agent_thought_count + 1,
|
||||
currency='USD',
|
||||
currency="USD",
|
||||
latency=0,
|
||||
created_by_role='account',
|
||||
created_by_role="account",
|
||||
created_by=self.user_id,
|
||||
)
|
||||
|
||||
@ -318,22 +325,22 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
return thought
|
||||
|
||||
def save_agent_thought(self,
|
||||
agent_thought: MessageAgentThought,
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
observation: Union[str, dict],
|
||||
tool_invoke_meta: Union[str, dict],
|
||||
answer: str,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage = None) -> MessageAgentThought:
|
||||
def save_agent_thought(
|
||||
self,
|
||||
agent_thought: MessageAgentThought,
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
observation: Union[str, dict],
|
||||
tool_invoke_meta: Union[str, dict],
|
||||
answer: str,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage = None,
|
||||
) -> MessageAgentThought:
|
||||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
agent_thought = db.session.query(MessageAgentThought).filter(
|
||||
MessageAgentThought.id == agent_thought.id
|
||||
).first()
|
||||
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
|
||||
if thought is not None:
|
||||
agent_thought.thought = thought
|
||||
@ -356,7 +363,7 @@ class BaseAgentRunner(AppRunner):
|
||||
observation = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
observation = json.dumps(observation)
|
||||
|
||||
|
||||
agent_thought.observation = observation
|
||||
|
||||
if answer is not None:
|
||||
@ -364,7 +371,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
if messages_ids is not None and len(messages_ids) > 0:
|
||||
agent_thought.message_files = json.dumps(messages_ids)
|
||||
|
||||
|
||||
if llm_usage:
|
||||
agent_thought.message_token = llm_usage.prompt_tokens
|
||||
agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
@ -377,7 +384,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
# check if tool labels is not empty
|
||||
labels = agent_thought.tool_labels or {}
|
||||
tools = agent_thought.tool.split(';') if agent_thought.tool else []
|
||||
tools = agent_thought.tool.split(";") if agent_thought.tool else []
|
||||
for tool in tools:
|
||||
if not tool:
|
||||
continue
|
||||
@ -386,7 +393,7 @@ class BaseAgentRunner(AppRunner):
|
||||
if tool_label:
|
||||
labels[tool] = tool_label.to_dict()
|
||||
else:
|
||||
labels[tool] = {'en_US': tool, 'zh_Hans': tool}
|
||||
labels[tool] = {"en_US": tool, "zh_Hans": tool}
|
||||
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
@ -401,14 +408,18 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
|
||||
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
|
||||
"""
|
||||
convert tool variables to db variables
|
||||
"""
|
||||
db_variables = db.session.query(ToolConversationVariables).filter(
|
||||
ToolConversationVariables.conversation_id == self.message.conversation_id,
|
||||
).first()
|
||||
db_variables = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
||||
@ -425,9 +436,14 @@ class BaseAgentRunner(AppRunner):
|
||||
if isinstance(prompt_message, SystemPromptMessage):
|
||||
result.append(prompt_message)
|
||||
|
||||
messages: list[Message] = db.session.query(Message).filter(
|
||||
Message.conversation_id == self.message.conversation_id,
|
||||
).order_by(Message.created_at.asc()).all()
|
||||
messages: list[Message] = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.order_by(Message.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
if message.id == self.message.id:
|
||||
@ -439,13 +455,13 @@ class BaseAgentRunner(AppRunner):
|
||||
for agent_thought in agent_thoughts:
|
||||
tools = agent_thought.tool
|
||||
if tools:
|
||||
tools = tools.split(';')
|
||||
tools = tools.split(";")
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_call_response: list[ToolPromptMessage] = []
|
||||
try:
|
||||
tool_inputs = json.loads(agent_thought.tool_input)
|
||||
except Exception as e:
|
||||
tool_inputs = { tool: {} for tool in tools }
|
||||
tool_inputs = {tool: {} for tool in tools}
|
||||
try:
|
||||
tool_responses = json.loads(agent_thought.observation)
|
||||
except Exception as e:
|
||||
@ -454,27 +470,33 @@ class BaseAgentRunner(AppRunner):
|
||||
for tool in tools:
|
||||
# generate a uuid for tool call
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
tool_calls.append(AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool,
|
||||
arguments=json.dumps(tool_inputs.get(tool, {})),
|
||||
tool_calls.append(
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool,
|
||||
arguments=json.dumps(tool_inputs.get(tool, {})),
|
||||
),
|
||||
)
|
||||
))
|
||||
tool_call_response.append(ToolPromptMessage(
|
||||
content=tool_responses.get(tool, agent_thought.observation),
|
||||
name=tool,
|
||||
tool_call_id=tool_call_id,
|
||||
))
|
||||
)
|
||||
tool_call_response.append(
|
||||
ToolPromptMessage(
|
||||
content=tool_responses.get(tool, agent_thought.observation),
|
||||
name=tool,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
)
|
||||
|
||||
result.extend([
|
||||
AssistantPromptMessage(
|
||||
content=agent_thought.thought,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
*tool_call_response
|
||||
])
|
||||
result.extend(
|
||||
[
|
||||
AssistantPromptMessage(
|
||||
content=agent_thought.thought,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
*tool_call_response,
|
||||
]
|
||||
)
|
||||
if not tools:
|
||||
result.append(AssistantPromptMessage(content=agent_thought.thought))
|
||||
else:
|
||||
@ -496,10 +518,7 @@ class BaseAgentRunner(AppRunner):
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
files,
|
||||
file_extra_config
|
||||
)
|
||||
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
|
||||
@ -25,17 +25,19 @@ from models.model import Message
|
||||
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ['wenxin']
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
_historic_prompt_messages: list[PromptMessage] = None
|
||||
_agent_scratchpad: list[AgentScratchpadUnit] = None
|
||||
_instruction: str = None
|
||||
_query: str = None
|
||||
_prompt_messages_tools: list[PromptMessage] = None
|
||||
|
||||
def run(self, message: Message,
|
||||
query: str,
|
||||
inputs: dict[str, str],
|
||||
) -> Union[Generator, LLMResult]:
|
||||
def run(
|
||||
self,
|
||||
message: Message,
|
||||
query: str,
|
||||
inputs: dict[str, str],
|
||||
) -> Union[Generator, LLMResult]:
|
||||
"""
|
||||
Run Cot agent application
|
||||
"""
|
||||
@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
# check model mode
|
||||
if 'Observation' not in app_generate_entity.model_conf.stop:
|
||||
if "Observation" not in app_generate_entity.model_conf.stop:
|
||||
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
|
||||
app_generate_entity.model_conf.stop.append('Observation')
|
||||
app_generate_entity.model_conf.stop.append("Observation")
|
||||
|
||||
app_config = self.app_config
|
||||
|
||||
# init instruction
|
||||
inputs = inputs or {}
|
||||
instruction = app_config.prompt_template.simple_prompt_template
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(
|
||||
instruction, inputs)
|
||||
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
|
||||
iteration_step = 1
|
||||
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
|
||||
@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
function_call_state = True
|
||||
llm_usage = {
|
||||
'usage': None
|
||||
}
|
||||
final_answer = ''
|
||||
llm_usage = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict['usage']
|
||||
llm_usage = final_llm_usage_dict["usage"]
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
message_file_ids = []
|
||||
|
||||
agent_thought = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message='',
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
messages_ids=message_file_ids
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
if iteration_step > 1:
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# recalc llm max tokens
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
raise ValueError("failed to invoke llm")
|
||||
|
||||
usage_dict = {}
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(
|
||||
chunks, usage_dict)
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response='',
|
||||
thought='',
|
||||
action_str='',
|
||||
observation='',
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
# publish agent thought if it's first iteration
|
||||
if iteration_step == 1:
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
yield LLMResultChunk(
|
||||
model=self.model_config.model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint='',
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=chunk
|
||||
),
|
||||
usage=None
|
||||
)
|
||||
system_fingerprint="",
|
||||
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
|
||||
)
|
||||
|
||||
scratchpad.thought = scratchpad.thought.strip(
|
||||
) or 'I am thinking about how to help you'
|
||||
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
|
||||
self._agent_scratchpad.append(scratchpad)
|
||||
|
||||
# get llm usage
|
||||
if 'usage' in usage_dict:
|
||||
increase_usage(llm_usage, usage_dict['usage'])
|
||||
if "usage" in usage_dict:
|
||||
increase_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
usage_dict['usage'] = LLMUsage.empty_usage()
|
||||
usage_dict["usage"] = LLMUsage.empty_usage()
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=scratchpad.action.action_name if scratchpad.action else '',
|
||||
tool_input={
|
||||
scratchpad.action.action_name: scratchpad.action.action_input
|
||||
} if scratchpad.action else {},
|
||||
tool_name=scratchpad.action.action_name if scratchpad.action else "",
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
|
||||
tool_invoke_meta={},
|
||||
thought=scratchpad.thought,
|
||||
observation='',
|
||||
observation="",
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=[],
|
||||
llm_usage=usage_dict['usage']
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
|
||||
if not scratchpad.is_final():
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
if not scratchpad.action:
|
||||
# failed to extract action, return final answer directly
|
||||
final_answer = ''
|
||||
final_answer = ""
|
||||
else:
|
||||
if scratchpad.action.action_name.lower() == "final answer":
|
||||
# action is final answer, return final answer directly
|
||||
try:
|
||||
if isinstance(scratchpad.action.action_input, dict):
|
||||
final_answer = json.dumps(
|
||||
scratchpad.action.action_input)
|
||||
final_answer = json.dumps(scratchpad.action.action_input)
|
||||
elif isinstance(scratchpad.action.action_input, str):
|
||||
final_answer = scratchpad.action.action_input
|
||||
else:
|
||||
final_answer = f'{scratchpad.action.action_input}'
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
except json.JSONDecodeError:
|
||||
final_answer = f'{scratchpad.action.action_input}'
|
||||
final_answer = f"{scratchpad.action.action_input}"
|
||||
else:
|
||||
function_call_state = True
|
||||
# action is tool call, invoke tool
|
||||
@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name=scratchpad.action.action_name,
|
||||
tool_input={
|
||||
scratchpad.action.action_name: scratchpad.action.action_input},
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
|
||||
thought=scratchpad.thought,
|
||||
observation={
|
||||
scratchpad.action.action_name: tool_invoke_response},
|
||||
tool_invoke_meta={
|
||||
scratchpad.action.action_name: tool_invoke_meta.to_dict()},
|
||||
observation={scratchpad.action.action_name: tool_invoke_response},
|
||||
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=message_file_ids,
|
||||
llm_usage=usage_dict['usage']
|
||||
llm_usage=usage_dict["usage"],
|
||||
)
|
||||
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
# update prompt tool message
|
||||
for prompt_tool in self._prompt_messages_tools:
|
||||
@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
),
|
||||
usage=llm_usage['usage']
|
||||
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
|
||||
),
|
||||
system_fingerprint=''
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
tool_name='',
|
||||
tool_name="",
|
||||
tool_input={},
|
||||
tool_invoke_meta={},
|
||||
thought=final_answer,
|
||||
observation={},
|
||||
answer=final_answer,
|
||||
messages_ids=[]
|
||||
messages_ids=[],
|
||||
)
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
|
||||
system_fingerprint=''
|
||||
)), PublishFrom.APPLICATION_MANAGER)
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
|
||||
tool_instances: dict[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
def _handle_invoke_action(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
tool_instances: dict[str, Tool],
|
||||
message_file_ids: list[str],
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> tuple[str, ToolInvokeMeta]:
|
||||
"""
|
||||
handle invoke action
|
||||
:param action: action
|
||||
@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
# publish files
|
||||
for message_file_id, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(
|
||||
tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file_id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
|
||||
@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
"""
|
||||
convert dict to action
|
||||
"""
|
||||
return AgentScratchpadUnit.Action(
|
||||
action_name=action['action'],
|
||||
action_input=action['action_input']
|
||||
)
|
||||
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
|
||||
|
||||
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
|
||||
"""
|
||||
@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
"""
|
||||
for key, value in inputs.items():
|
||||
try:
|
||||
instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
|
||||
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
@abstractmethod
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
organize prompt messages
|
||||
organize prompt messages
|
||||
"""
|
||||
|
||||
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
|
||||
"""
|
||||
format assistant message
|
||||
format assistant message
|
||||
"""
|
||||
message = ''
|
||||
message = ""
|
||||
for scratchpad in agent_scratchpad:
|
||||
if scratchpad.is_final():
|
||||
message += f"Final Answer: {scratchpad.agent_response}"
|
||||
@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
|
||||
return message
|
||||
|
||||
def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _organize_historic_prompt_messages(
|
||||
self, current_session_messages: list[PromptMessage] = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
organize historic prompt messages
|
||||
organize historic prompt messages
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
scratchpads: list[AgentScratchpadUnit] = []
|
||||
@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
if not current_scratchpad:
|
||||
current_scratchpad = AgentScratchpadUnit(
|
||||
agent_response=message.content,
|
||||
thought=message.content or 'I am thinking about how to help you',
|
||||
action_str='',
|
||||
thought=message.content or "I am thinking about how to help you",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
try:
|
||||
current_scratchpad.action = AgentScratchpadUnit.Action(
|
||||
action_name=message.tool_calls[0].function.name,
|
||||
action_input=json.loads(
|
||||
message.tool_calls[0].function.arguments)
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(
|
||||
current_scratchpad.action.to_dict()
|
||||
action_input=json.loads(message.tool_calls[0].function.arguments),
|
||||
)
|
||||
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
|
||||
except:
|
||||
pass
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
current_scratchpad.observation = message.content
|
||||
elif isinstance(message, UserPromptMessage):
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(
|
||||
content=self._format_assistant_message(scratchpads)
|
||||
))
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
scratchpads = []
|
||||
current_scratchpad = None
|
||||
|
||||
result.append(message)
|
||||
|
||||
if scratchpads:
|
||||
result.append(AssistantPromptMessage(
|
||||
content=self._format_assistant_message(scratchpads)
|
||||
))
|
||||
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
|
||||
|
||||
historic_prompts = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=current_session_messages or [],
|
||||
history_messages=result,
|
||||
memory=self.memory
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
return historic_prompts
|
||||
|
||||
@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = first_prompt \
|
||||
.replace("{{instruction}}", self._instruction) \
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
|
||||
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
|
||||
system_prompt = (
|
||||
first_prompt.replace("{{instruction}}", self._instruction)
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
|
||||
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
|
||||
)
|
||||
|
||||
return SystemPromptMessage(content=system_prompt)
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize
|
||||
Organize
|
||||
"""
|
||||
# organize system prompt
|
||||
system_message = self._organize_system_prompt()
|
||||
@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
if not agent_scratchpad:
|
||||
assistant_messages = []
|
||||
else:
|
||||
assistant_message = AssistantPromptMessage(content='')
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
|
||||
if assistant_messages:
|
||||
# organize historic prompt messages
|
||||
historic_messages = self._organize_historic_prompt_messages([
|
||||
system_message,
|
||||
*query_messages,
|
||||
*assistant_messages,
|
||||
UserPromptMessage(content='continue')
|
||||
])
|
||||
historic_messages = self._organize_historic_prompt_messages(
|
||||
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
|
||||
)
|
||||
messages = [
|
||||
system_message,
|
||||
*historic_messages,
|
||||
*query_messages,
|
||||
*assistant_messages,
|
||||
UserPromptMessage(content='continue')
|
||||
UserPromptMessage(content="continue"),
|
||||
]
|
||||
else:
|
||||
# organize historic prompt messages
|
||||
|
||||
@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
prompt_entity = self.app_config.agent.prompt
|
||||
first_prompt = prompt_entity.first_prompt
|
||||
|
||||
system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
|
||||
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
|
||||
|
||||
system_prompt = (
|
||||
first_prompt.replace("{{instruction}}", self._instruction)
|
||||
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
|
||||
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
|
||||
)
|
||||
|
||||
return system_prompt
|
||||
|
||||
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
|
||||
@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
|
||||
# organize current assistant messages
|
||||
agent_scratchpad = self._agent_scratchpad
|
||||
assistant_prompt = ''
|
||||
assistant_prompt = ""
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assistant_prompt += f"Final Answer: {unit.agent_response}"
|
||||
@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
query_prompt = f"Question: {self._query}"
|
||||
|
||||
# join all messages
|
||||
prompt = system_prompt \
|
||||
.replace("{{historic_messages}}", historic_prompt) \
|
||||
.replace("{{agent_scratchpad}}", assistant_prompt) \
|
||||
prompt = (
|
||||
system_prompt.replace("{{historic_messages}}", historic_prompt)
|
||||
.replace("{{agent_scratchpad}}", assistant_prompt)
|
||||
.replace("{{query}}", query_prompt)
|
||||
)
|
||||
|
||||
return [UserPromptMessage(content=prompt)]
|
||||
return [UserPromptMessage(content=prompt)]
|
||||
|
||||
@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel):
|
||||
"""
|
||||
Agent Tool Entity.
|
||||
"""
|
||||
|
||||
provider_type: Literal["builtin", "api", "workflow"]
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel):
|
||||
"""
|
||||
Agent Prompt Entity.
|
||||
"""
|
||||
|
||||
first_prompt: str
|
||||
next_iteration: str
|
||||
|
||||
@ -31,6 +33,7 @@ class AgentScratchpadUnit(BaseModel):
|
||||
"""
|
||||
Action Entity.
|
||||
"""
|
||||
|
||||
action_name: str
|
||||
action_input: Union[dict, str]
|
||||
|
||||
@ -39,8 +42,8 @@ class AgentScratchpadUnit(BaseModel):
|
||||
Convert to dictionary.
|
||||
"""
|
||||
return {
|
||||
'action': self.action_name,
|
||||
'action_input': self.action_input,
|
||||
"action": self.action_name,
|
||||
"action_input": self.action_input,
|
||||
}
|
||||
|
||||
agent_response: Optional[str] = None
|
||||
@ -54,10 +57,10 @@ class AgentScratchpadUnit(BaseModel):
|
||||
Check if the scratchpad unit is final.
|
||||
"""
|
||||
return self.action is None or (
|
||||
'final' in self.action.action_name.lower() and
|
||||
'answer' in self.action.action_name.lower()
|
||||
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
|
||||
)
|
||||
|
||||
|
||||
class AgentEntity(BaseModel):
|
||||
"""
|
||||
Agent Entity.
|
||||
@ -67,8 +70,9 @@ class AgentEntity(BaseModel):
|
||||
"""
|
||||
Agent Strategy.
|
||||
"""
|
||||
CHAIN_OF_THOUGHT = 'chain-of-thought'
|
||||
FUNCTION_CALLING = 'function-calling'
|
||||
|
||||
CHAIN_OF_THOUGHT = "chain-of-thought"
|
||||
FUNCTION_CALLING = "function-calling"
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
|
||||
@ -24,11 +24,9 @@ from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
def run(self,
|
||||
message: Message, query: str, **kwargs: Any
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run FunctionCall agent application
|
||||
"""
|
||||
@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
# continue to run until there is not any tool call
|
||||
function_call_state = True
|
||||
llm_usage = {
|
||||
'usage': None
|
||||
}
|
||||
final_answer = ''
|
||||
llm_usage = {"usage": None}
|
||||
final_answer = ""
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = app_generate_entity.trace_manager
|
||||
|
||||
|
||||
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
|
||||
if not final_llm_usage_dict['usage']:
|
||||
final_llm_usage_dict['usage'] = usage
|
||||
if not final_llm_usage_dict["usage"]:
|
||||
final_llm_usage_dict["usage"] = usage
|
||||
else:
|
||||
llm_usage = final_llm_usage_dict['usage']
|
||||
llm_usage = final_llm_usage_dict["usage"]
|
||||
llm_usage.prompt_tokens += usage.prompt_tokens
|
||||
llm_usage.completion_tokens += usage.completion_tokens
|
||||
llm_usage.prompt_price += usage.prompt_price
|
||||
@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
message_file_ids = []
|
||||
agent_thought = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message='',
|
||||
tool_name='',
|
||||
tool_input='',
|
||||
messages_ids=message_file_ids
|
||||
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
# recalc llm max tokens
|
||||
@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
|
||||
# save full response
|
||||
response = ''
|
||||
response = ""
|
||||
|
||||
# save tool call names and inputs
|
||||
tool_call_names = ''
|
||||
tool_call_inputs = ''
|
||||
tool_call_names = ""
|
||||
tool_call_inputs = ""
|
||||
|
||||
current_llm_usage = None
|
||||
|
||||
@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
is_first_chunk = True
|
||||
for chunk in chunks:
|
||||
if is_first_chunk:
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
is_first_chunk = False
|
||||
# check if there is any tool call
|
||||
if self.check_tool_calls(chunk):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_tool_calls(chunk))
|
||||
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps({
|
||||
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
||||
}, ensure_ascii=False)
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({
|
||||
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
||||
})
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
if isinstance(chunk.delta.message.content, list):
|
||||
@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
if self.check_blocking_tool_calls(result):
|
||||
function_call_state = True
|
||||
tool_calls.extend(self.extract_blocking_tool_calls(result))
|
||||
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
|
||||
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
|
||||
try:
|
||||
tool_call_inputs = json.dumps({
|
||||
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
||||
}, ensure_ascii=False)
|
||||
tool_call_inputs = json.dumps(
|
||||
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
# ensure ascii to avoid encoding error
|
||||
tool_call_inputs = json.dumps({
|
||||
tool_call[1]: tool_call[2] for tool_call in tool_calls
|
||||
})
|
||||
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
|
||||
|
||||
if result.usage:
|
||||
increase_usage(llm_usage, result.usage)
|
||||
@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
response += result.message.content
|
||||
|
||||
if not result.message.content:
|
||||
result.message.content = ''
|
||||
result.message.content = ""
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
index=0,
|
||||
message=result.message,
|
||||
usage=result.usage,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
assistant_message = AssistantPromptMessage(
|
||||
content='',
|
||||
tool_calls=[]
|
||||
)
|
||||
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
|
||||
if tool_calls:
|
||||
assistant_message.tool_calls=[
|
||||
assistant_message.tool_calls = [
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call[0],
|
||||
type='function',
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_call[1],
|
||||
arguments=json.dumps(tool_call[2], ensure_ascii=False)
|
||||
)
|
||||
) for tool_call in tool_calls
|
||||
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
|
||||
),
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
else:
|
||||
assistant_message.content = response
|
||||
|
||||
|
||||
self._current_thoughts.append(assistant_message)
|
||||
|
||||
# save thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
agent_thought=agent_thought,
|
||||
tool_name=tool_call_names,
|
||||
tool_input=tool_call_inputs,
|
||||
thought=response,
|
||||
@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
observation=None,
|
||||
answer=response,
|
||||
messages_ids=[],
|
||||
llm_usage=current_llm_usage
|
||||
llm_usage=current_llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
final_answer += response + '\n'
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
final_answer += response + "\n"
|
||||
|
||||
# call tools
|
||||
tool_responses = []
|
||||
@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": f"there is not a tool named {tool_call_name}",
|
||||
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
|
||||
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
|
||||
}
|
||||
else:
|
||||
# invoke tool
|
||||
@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file_id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file_id)
|
||||
|
||||
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": tool_invoke_response,
|
||||
"meta": tool_invoke_meta.to_dict()
|
||||
"meta": tool_invoke_meta.to_dict(),
|
||||
}
|
||||
|
||||
|
||||
tool_responses.append(tool_response)
|
||||
if tool_response['tool_response'] is not None:
|
||||
if tool_response["tool_response"] is not None:
|
||||
self._current_thoughts.append(
|
||||
ToolPromptMessage(
|
||||
content=tool_response['tool_response'],
|
||||
content=tool_response["tool_response"],
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_call_name,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if len(tool_responses) > 0:
|
||||
# save agent thought
|
||||
self.save_agent_thought(
|
||||
agent_thought=agent_thought,
|
||||
agent_thought=agent_thought,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
thought=None,
|
||||
tool_invoke_meta={
|
||||
tool_response['tool_call_name']: tool_response['meta']
|
||||
for tool_response in tool_responses
|
||||
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
|
||||
},
|
||||
observation={
|
||||
tool_response['tool_call_name']: tool_response['tool_response']
|
||||
tool_response["tool_call_name"]: tool_response["tool_response"]
|
||||
for tool_response in tool_responses
|
||||
},
|
||||
answer=None,
|
||||
messages_ids=message_file_ids
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
self.queue_manager.publish(QueueAgentThoughtEvent(
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
# update prompt tool
|
||||
for prompt_tool in prompt_messages_tools:
|
||||
@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
|
||||
self.update_db_variables(self.variables_pool, self.db_variables_pool)
|
||||
# publish end event
|
||||
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=final_answer
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
|
||||
system_fingerprint=''
|
||||
)), PublishFrom.APPLICATION_MANAGER)
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
|
||||
"""
|
||||
@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
if llm_result_chunk.delta.message.tool_calls:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
|
||||
"""
|
||||
Check if there is any blocking tool call in llm result
|
||||
@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
return True
|
||||
return False
|
||||
|
||||
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
def extract_tool_calls(
|
||||
self, llm_result_chunk: LLMResultChunk
|
||||
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
"""
|
||||
Extract tool calls from llm result chunk
|
||||
|
||||
@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_calls = []
|
||||
for prompt_message in llm_result_chunk.delta.message.tool_calls:
|
||||
args = {}
|
||||
if prompt_message.function.arguments != '':
|
||||
if prompt_message.function.arguments != "":
|
||||
args = json.loads(prompt_message.function.arguments)
|
||||
|
||||
tool_calls.append((
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
))
|
||||
tool_calls.append(
|
||||
(
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
|
||||
"""
|
||||
Extract blocking tool calls from llm result
|
||||
@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_calls = []
|
||||
for prompt_message in llm_result.message.tool_calls:
|
||||
args = {}
|
||||
if prompt_message.function.arguments != '':
|
||||
if prompt_message.function.arguments != "":
|
||||
args = json.loads(prompt_message.function.arguments)
|
||||
|
||||
tool_calls.append((
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
))
|
||||
tool_calls.append(
|
||||
(
|
||||
prompt_message.id,
|
||||
prompt_message.function.name,
|
||||
args,
|
||||
)
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _init_system_message(
|
||||
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
return [
|
||||
SystemPromptMessage(content=prompt_template),
|
||||
]
|
||||
|
||||
|
||||
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
As for now, gpt supports both fc and vision at the first iteration.
|
||||
@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = '\n'.join([
|
||||
content.data if content.type == PromptMessageContentType.TEXT else
|
||||
'[image]' if content.type == PromptMessageContentType.IMAGE else
|
||||
'[file]'
|
||||
for content in prompt_message.content
|
||||
])
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query, [])
|
||||
|
||||
@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [
|
||||
*self.history_prompt_messages,
|
||||
*query_prompt_messages,
|
||||
*self._current_thoughts
|
||||
]
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
|
||||
@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
|
||||
class CotAgentOutputParser:
|
||||
@classmethod
|
||||
def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
|
||||
Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||
def handle_react_stream_output(
|
||||
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
|
||||
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||
def parse_action(json_str):
|
||||
try:
|
||||
action = json.loads(json_str)
|
||||
@ -22,7 +23,7 @@ class CotAgentOutputParser:
|
||||
action = action[0]
|
||||
|
||||
for key, value in action.items():
|
||||
if 'input' in key.lower():
|
||||
if "input" in key.lower():
|
||||
action_input = value
|
||||
else:
|
||||
action_name = value
|
||||
@ -33,37 +34,37 @@ class CotAgentOutputParser:
|
||||
action_input=action_input,
|
||||
)
|
||||
else:
|
||||
return json_str or ''
|
||||
return json_str or ""
|
||||
except:
|
||||
return json_str or ''
|
||||
|
||||
return json_str or ""
|
||||
|
||||
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
|
||||
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
|
||||
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
|
||||
if not code_blocks:
|
||||
return
|
||||
for block in code_blocks:
|
||||
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
|
||||
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
|
||||
yield parse_action(json_text)
|
||||
|
||||
code_block_cache = ''
|
||||
|
||||
code_block_cache = ""
|
||||
code_block_delimiter_count = 0
|
||||
in_code_block = False
|
||||
json_cache = ''
|
||||
json_cache = ""
|
||||
json_quote_count = 0
|
||||
in_json = False
|
||||
got_json = False
|
||||
|
||||
action_cache = ''
|
||||
action_str = 'action:'
|
||||
action_cache = ""
|
||||
action_str = "action:"
|
||||
action_idx = 0
|
||||
|
||||
thought_cache = ''
|
||||
thought_str = 'thought:'
|
||||
thought_cache = ""
|
||||
thought_str = "thought:"
|
||||
thought_idx = 0
|
||||
|
||||
for response in llm_response:
|
||||
if response.delta.usage:
|
||||
usage_dict['usage'] = response.delta.usage
|
||||
usage_dict["usage"] = response.delta.usage
|
||||
response = response.delta.message.content
|
||||
if not isinstance(response, str):
|
||||
continue
|
||||
@ -72,24 +73,24 @@ class CotAgentOutputParser:
|
||||
index = 0
|
||||
while index < len(response):
|
||||
steps = 1
|
||||
delta = response[index:index+steps]
|
||||
last_character = response[index-1] if index > 0 else ''
|
||||
delta = response[index : index + steps]
|
||||
last_character = response[index - 1] if index > 0 else ""
|
||||
|
||||
if delta == '`':
|
||||
if delta == "`":
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count += 1
|
||||
else:
|
||||
if not in_code_block:
|
||||
if code_block_delimiter_count > 0:
|
||||
yield code_block_cache
|
||||
code_block_cache = ''
|
||||
code_block_cache = ""
|
||||
else:
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count = 0
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
if delta.lower() == action_str[action_idx] and action_idx == 0:
|
||||
if last_character not in ['\n', ' ', '']:
|
||||
if last_character not in {"\n", " ", ""}:
|
||||
index += steps
|
||||
yield delta
|
||||
continue
|
||||
@ -97,7 +98,7 @@ class CotAgentOutputParser:
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
action_cache = ''
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
@ -105,18 +106,18 @@ class CotAgentOutputParser:
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
action_cache = ''
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
else:
|
||||
if action_cache:
|
||||
yield action_cache
|
||||
action_cache = ''
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
|
||||
|
||||
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
|
||||
if last_character not in ['\n', ' ', '']:
|
||||
if last_character not in {"\n", " ", ""}:
|
||||
index += steps
|
||||
yield delta
|
||||
continue
|
||||
@ -124,7 +125,7 @@ class CotAgentOutputParser:
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
thought_cache = ''
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
@ -132,31 +133,31 @@ class CotAgentOutputParser:
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
thought_cache = ''
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
else:
|
||||
if thought_cache:
|
||||
yield thought_cache
|
||||
thought_cache = ''
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
|
||||
if code_block_delimiter_count == 3:
|
||||
if in_code_block:
|
||||
yield from extra_json_from_code_block(code_block_cache)
|
||||
code_block_cache = ''
|
||||
|
||||
code_block_cache = ""
|
||||
|
||||
in_code_block = not in_code_block
|
||||
code_block_delimiter_count = 0
|
||||
|
||||
if not in_code_block:
|
||||
# handle single json
|
||||
if delta == '{':
|
||||
if delta == "{":
|
||||
json_quote_count += 1
|
||||
in_json = True
|
||||
json_cache += delta
|
||||
elif delta == '}':
|
||||
elif delta == "}":
|
||||
json_cache += delta
|
||||
if json_quote_count > 0:
|
||||
json_quote_count -= 1
|
||||
@ -172,12 +173,12 @@ class CotAgentOutputParser:
|
||||
if got_json:
|
||||
got_json = False
|
||||
yield parse_action(json_cache)
|
||||
json_cache = ''
|
||||
json_cache = ""
|
||||
json_quote_count = 0
|
||||
in_json = False
|
||||
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
yield delta.replace('`', '')
|
||||
yield delta.replace("`", "")
|
||||
|
||||
index += steps
|
||||
|
||||
@ -186,4 +187,3 @@ class CotAgentOutputParser:
|
||||
|
||||
if json_cache:
|
||||
yield parse_action(json_cache)
|
||||
|
||||
|
||||
@ -41,7 +41,8 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
|
||||
{{historic_messages}}
|
||||
Question: {{query}}
|
||||
{{agent_scratchpad}}
|
||||
Thought:"""
|
||||
Thought:""" # noqa: E501
|
||||
|
||||
|
||||
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
|
||||
Thought:"""
|
||||
@ -86,19 +87,20 @@ Action:
|
||||
```
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
|
||||
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
|
||||
|
||||
REACT_PROMPT_TEMPLATES = {
|
||||
'english': {
|
||||
'chat': {
|
||||
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
|
||||
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
|
||||
"english": {
|
||||
"chat": {
|
||||
"prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
|
||||
"agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
|
||||
},
|
||||
"completion": {
|
||||
"prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
|
||||
"agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
|
||||
},
|
||||
'completion': {
|
||||
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
|
||||
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,34 +26,24 @@ class BaseAppConfigManager:
|
||||
config_dict = dict(config_dict.items())
|
||||
|
||||
additional_features = AppAdditionalFeatures()
|
||||
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
|
||||
|
||||
additional_features.file_upload = FileUploadConfigManager.convert(
|
||||
config=config_dict,
|
||||
is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
|
||||
config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
|
||||
)
|
||||
|
||||
additional_features.opening_statement, additional_features.suggested_questions = \
|
||||
OpeningStatementConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
additional_features.opening_statement, additional_features.suggested_questions = (
|
||||
OpeningStatementConfigManager.convert(config=config_dict)
|
||||
)
|
||||
|
||||
additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
|
||||
additional_features.more_like_this = MoreLikeThisConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)
|
||||
|
||||
additional_features.speech_to_text = SpeechToTextConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)
|
||||
|
||||
additional_features.text_to_speech = TextToSpeechConfigManager.convert(
|
||||
config=config_dict
|
||||
)
|
||||
additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)
|
||||
|
||||
return additional_features
|
||||
|
||||
@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory
|
||||
class SensitiveWordAvoidanceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
|
||||
sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
|
||||
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
|
||||
if not sensitive_word_avoidance_dict:
|
||||
return None
|
||||
|
||||
if sensitive_word_avoidance_dict.get('enabled'):
|
||||
if sensitive_word_avoidance_dict.get("enabled"):
|
||||
return SensitiveWordAvoidanceEntity(
|
||||
type=sensitive_word_avoidance_dict.get('type'),
|
||||
config=sensitive_word_avoidance_dict.get('config'),
|
||||
type=sensitive_word_avoidance_dict.get("type"),
|
||||
config=sensitive_word_avoidance_dict.get("config"),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
|
||||
-> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id, config: dict, only_structure_validate: bool = False
|
||||
) -> tuple[dict, list[str]]:
|
||||
if not config.get("sensitive_word_avoidance"):
|
||||
config["sensitive_word_avoidance"] = {
|
||||
"enabled": False
|
||||
}
|
||||
config["sensitive_word_avoidance"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["sensitive_word_avoidance"], dict):
|
||||
raise ValueError("sensitive_word_avoidance must be of dict type")
|
||||
@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager:
|
||||
typ = config["sensitive_word_avoidance"]["type"]
|
||||
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
|
||||
|
||||
ModerationFactory.validate_config(
|
||||
name=typ,
|
||||
tenant_id=tenant_id,
|
||||
config=sensitive_word_avoidance_config
|
||||
)
|
||||
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
|
||||
|
||||
return config, ["sensitive_word_avoidance"]
|
||||
|
||||
@ -12,67 +12,70 @@ class AgentConfigManager:
|
||||
|
||||
:param config: model config args
|
||||
"""
|
||||
if 'agent_mode' in config and config['agent_mode'] \
|
||||
and 'enabled' in config['agent_mode']:
|
||||
if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
agent_strategy = agent_dict.get("strategy", "cot")
|
||||
|
||||
agent_dict = config.get('agent_mode', {})
|
||||
agent_strategy = agent_dict.get('strategy', 'cot')
|
||||
|
||||
if agent_strategy == 'function_call':
|
||||
if agent_strategy == "function_call":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
elif agent_strategy == 'cot' or agent_strategy == 'react':
|
||||
elif agent_strategy in {"cot", "react"}:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
else:
|
||||
# old configs, try to detect default strategy
|
||||
if config['model']['provider'] == 'openai':
|
||||
if config["model"]["provider"] == "openai":
|
||||
strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
else:
|
||||
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get('tools', []):
|
||||
for tool in agent_dict.get("tools", []):
|
||||
keys = tool.keys()
|
||||
if len(keys) >= 4:
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
continue
|
||||
|
||||
agent_tool_properties = {
|
||||
'provider_type': tool['provider_type'],
|
||||
'provider_id': tool['provider_id'],
|
||||
'tool_name': tool['tool_name'],
|
||||
'tool_parameters': tool.get('tool_parameters', {})
|
||||
"provider_type": tool["provider_type"],
|
||||
"provider_id": tool["provider_id"],
|
||||
"tool_name": tool["tool_name"],
|
||||
"tool_parameters": tool.get("tool_parameters", {}),
|
||||
}
|
||||
|
||||
agent_tools.append(AgentToolEntity(**agent_tool_properties))
|
||||
|
||||
if 'strategy' in config['agent_mode'] and \
|
||||
config['agent_mode']['strategy'] not in ['react_router', 'router']:
|
||||
agent_prompt = agent_dict.get('prompt', None) or {}
|
||||
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
|
||||
"react_router",
|
||||
"router",
|
||||
}:
|
||||
agent_prompt = agent_dict.get("prompt", None) or {}
|
||||
# check model mode
|
||||
model_mode = config.get('model', {}).get('mode', 'completion')
|
||||
if model_mode == 'completion':
|
||||
model_mode = config.get("model", {}).get("mode", "completion")
|
||||
if model_mode == "completion":
|
||||
agent_prompt_entity = AgentPromptEntity(
|
||||
first_prompt=agent_prompt.get('first_prompt',
|
||||
REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
|
||||
next_iteration=agent_prompt.get('next_iteration',
|
||||
REACT_PROMPT_TEMPLATES['english']['completion'][
|
||||
'agent_scratchpad']),
|
||||
first_prompt=agent_prompt.get(
|
||||
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
|
||||
),
|
||||
next_iteration=agent_prompt.get(
|
||||
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
|
||||
),
|
||||
)
|
||||
else:
|
||||
agent_prompt_entity = AgentPromptEntity(
|
||||
first_prompt=agent_prompt.get('first_prompt',
|
||||
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
|
||||
next_iteration=agent_prompt.get('next_iteration',
|
||||
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
|
||||
first_prompt=agent_prompt.get(
|
||||
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
|
||||
),
|
||||
next_iteration=agent_prompt.get(
|
||||
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
|
||||
),
|
||||
)
|
||||
|
||||
return AgentEntity(
|
||||
provider=config['model']['provider'],
|
||||
model=config['model']['name'],
|
||||
provider=config["model"]["provider"],
|
||||
model=config["model"]["name"],
|
||||
strategy=strategy,
|
||||
prompt=agent_prompt_entity,
|
||||
tools=agent_tools,
|
||||
max_iteration=agent_dict.get('max_iteration', 5)
|
||||
max_iteration=agent_dict.get("max_iteration", 5),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@ -15,39 +15,38 @@ class DatasetConfigManager:
|
||||
:param config: model config args
|
||||
"""
|
||||
dataset_ids = []
|
||||
if 'datasets' in config.get('dataset_configs', {}):
|
||||
datasets = config.get('dataset_configs', {}).get('datasets', {
|
||||
'strategy': 'router',
|
||||
'datasets': []
|
||||
})
|
||||
if "datasets" in config.get("dataset_configs", {}):
|
||||
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
|
||||
|
||||
for dataset in datasets.get('datasets', []):
|
||||
for dataset in datasets.get("datasets", []):
|
||||
keys = list(dataset.keys())
|
||||
if len(keys) == 0 or keys[0] != 'dataset':
|
||||
if len(keys) == 0 or keys[0] != "dataset":
|
||||
continue
|
||||
|
||||
dataset = dataset['dataset']
|
||||
dataset = dataset["dataset"]
|
||||
|
||||
if 'enabled' not in dataset or not dataset['enabled']:
|
||||
if "enabled" not in dataset or not dataset["enabled"]:
|
||||
continue
|
||||
|
||||
dataset_id = dataset.get('id', None)
|
||||
dataset_id = dataset.get("id", None)
|
||||
if dataset_id:
|
||||
dataset_ids.append(dataset_id)
|
||||
|
||||
if 'agent_mode' in config and config['agent_mode'] \
|
||||
and 'enabled' in config['agent_mode'] \
|
||||
and config['agent_mode']['enabled']:
|
||||
if (
|
||||
"agent_mode" in config
|
||||
and config["agent_mode"]
|
||||
and "enabled" in config["agent_mode"]
|
||||
and config["agent_mode"]["enabled"]
|
||||
):
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
|
||||
agent_dict = config.get('agent_mode', {})
|
||||
|
||||
for tool in agent_dict.get('tools', []):
|
||||
for tool in agent_dict.get("tools", []):
|
||||
keys = tool.keys()
|
||||
if len(keys) == 1:
|
||||
# old standard
|
||||
key = list(tool.keys())[0]
|
||||
|
||||
if key != 'dataset':
|
||||
if key != "dataset":
|
||||
continue
|
||||
|
||||
tool_item = tool[key]
|
||||
@ -55,30 +54,28 @@ class DatasetConfigManager:
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
||||
dataset_id = tool_item['id']
|
||||
dataset_id = tool_item["id"]
|
||||
dataset_ids.append(dataset_id)
|
||||
|
||||
if len(dataset_ids) == 0:
|
||||
return None
|
||||
|
||||
# dataset configs
|
||||
if 'dataset_configs' in config and config.get('dataset_configs'):
|
||||
dataset_configs = config.get('dataset_configs')
|
||||
if "dataset_configs" in config and config.get("dataset_configs"):
|
||||
dataset_configs = config.get("dataset_configs")
|
||||
else:
|
||||
dataset_configs = {
|
||||
'retrieval_model': 'multiple'
|
||||
}
|
||||
query_variable = config.get('dataset_query_variable')
|
||||
dataset_configs = {"retrieval_model": "multiple"}
|
||||
query_variable = config.get("dataset_query_variable")
|
||||
|
||||
if dataset_configs['retrieval_model'] == 'single':
|
||||
if dataset_configs["retrieval_model"] == "single":
|
||||
return DatasetEntity(
|
||||
dataset_ids=dataset_ids,
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
)
|
||||
)
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
),
|
||||
)
|
||||
else:
|
||||
return DatasetEntity(
|
||||
@ -86,15 +83,15 @@ class DatasetConfigManager:
|
||||
retrieve_config=DatasetRetrieveConfigEntity(
|
||||
query_variable=query_variable,
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
|
||||
dataset_configs['retrieval_model']
|
||||
dataset_configs["retrieval_model"]
|
||||
),
|
||||
top_k=dataset_configs.get('top_k', 4),
|
||||
score_threshold=dataset_configs.get('score_threshold'),
|
||||
reranking_model=dataset_configs.get('reranking_model'),
|
||||
weights=dataset_configs.get('weights'),
|
||||
reranking_enabled=dataset_configs.get('reranking_enabled', True),
|
||||
rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
|
||||
)
|
||||
top_k=dataset_configs.get("top_k", 4),
|
||||
score_threshold=dataset_configs.get("score_threshold"),
|
||||
reranking_model=dataset_configs.get("reranking_model"),
|
||||
weights=dataset_configs.get("weights"),
|
||||
reranking_enabled=dataset_configs.get("reranking_enabled", True),
|
||||
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -111,13 +108,10 @@ class DatasetConfigManager:
|
||||
|
||||
# dataset_configs
|
||||
if not config.get("dataset_configs"):
|
||||
config["dataset_configs"] = {'retrieval_model': 'single'}
|
||||
config["dataset_configs"] = {"retrieval_model": "single"}
|
||||
|
||||
if not config["dataset_configs"].get("datasets"):
|
||||
config["dataset_configs"]["datasets"] = {
|
||||
"strategy": "router",
|
||||
"datasets": []
|
||||
}
|
||||
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
|
||||
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
@ -125,8 +119,9 @@ class DatasetConfigManager:
|
||||
if not isinstance(config["dataset_configs"], dict):
|
||||
raise ValueError("dataset_configs must be of object type")
|
||||
|
||||
need_manual_query_datasets = (config.get("dataset_configs")
|
||||
and config["dataset_configs"].get("datasets", {}).get("datasets"))
|
||||
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
|
||||
"datasets", {}
|
||||
).get("datasets")
|
||||
|
||||
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
|
||||
# Only check when mode is completion
|
||||
@ -148,10 +143,7 @@ class DatasetConfigManager:
|
||||
"""
|
||||
# Extract dataset config for legacy compatibility
|
||||
if not config.get("agent_mode"):
|
||||
config["agent_mode"] = {
|
||||
"enabled": False,
|
||||
"tools": []
|
||||
}
|
||||
config["agent_mode"] = {"enabled": False, "tools": []}
|
||||
|
||||
if not isinstance(config["agent_mode"], dict):
|
||||
raise ValueError("agent_mode must be of object type")
|
||||
@ -175,7 +167,7 @@ class DatasetConfigManager:
|
||||
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
|
||||
|
||||
has_datasets = False
|
||||
if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
|
||||
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}:
|
||||
for tool in config["agent_mode"]["tools"]:
|
||||
key = list(tool.keys())[0]
|
||||
if key == "dataset":
|
||||
@ -188,7 +180,7 @@ class DatasetConfigManager:
|
||||
if not isinstance(tool_item["enabled"], bool):
|
||||
raise ValueError("enabled in agent_mode.tools must be of boolean type")
|
||||
|
||||
if 'id' not in tool_item:
|
||||
if "id" not in tool_item:
|
||||
raise ValueError("id is required in dataset")
|
||||
|
||||
try:
|
||||
|
||||
@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager
|
||||
|
||||
class ModelConfigConverter:
|
||||
@classmethod
|
||||
def convert(cls, app_config: EasyUIBasedAppConfig,
|
||||
skip_check: bool = False) \
|
||||
-> ModelConfigWithCredentialsEntity:
|
||||
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
|
||||
"""
|
||||
Convert app model config dict to entity.
|
||||
:param app_config: app config
|
||||
@ -25,9 +23,7 @@ class ModelConfigConverter:
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=app_config.tenant_id,
|
||||
provider=model_config.provider,
|
||||
model_type=ModelType.LLM
|
||||
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
@ -38,8 +34,7 @@ class ModelConfigConverter:
|
||||
|
||||
# check model credentials
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_config.model
|
||||
model_type=ModelType.LLM, model=model_config.model
|
||||
)
|
||||
|
||||
if model_credentials is None:
|
||||
@ -51,8 +46,7 @@ class ModelConfigConverter:
|
||||
if not skip_check:
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_config.model,
|
||||
model_type=ModelType.LLM
|
||||
model=model_config.model, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
@ -69,24 +63,18 @@ class ModelConfigConverter:
|
||||
# model config
|
||||
completion_params = model_config.parameters
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = model_config.mode
|
||||
if not model_mode:
|
||||
mode_enum = model_type_instance.get_model_mode(
|
||||
model=model_config.model,
|
||||
credentials=model_credentials
|
||||
)
|
||||
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
|
||||
|
||||
model_mode = mode_enum.value
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
model_config.model,
|
||||
model_credentials
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
||||
|
||||
if not skip_check and not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
@ -13,23 +13,23 @@ class ModelConfigManager:
|
||||
:param config: model config args
|
||||
"""
|
||||
# model config
|
||||
model_config = config.get('model')
|
||||
model_config = config.get("model")
|
||||
|
||||
if not model_config:
|
||||
raise ValueError("model is required")
|
||||
|
||||
completion_params = model_config.get('completion_params')
|
||||
completion_params = model_config.get("completion_params")
|
||||
stop = []
|
||||
if 'stop' in completion_params:
|
||||
stop = completion_params['stop']
|
||||
del completion_params['stop']
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
# get model mode
|
||||
model_mode = model_config.get('mode')
|
||||
model_mode = model_config.get("mode")
|
||||
|
||||
return ModelConfigEntity(
|
||||
provider=config['model']['provider'],
|
||||
model=config['model']['name'],
|
||||
provider=config["model"]["provider"],
|
||||
model=config["model"]["name"],
|
||||
mode=model_mode,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
@ -43,7 +43,7 @@ class ModelConfigManager:
|
||||
:param tenant_id: tenant id
|
||||
:param config: app model config args
|
||||
"""
|
||||
if 'model' not in config:
|
||||
if "model" not in config:
|
||||
raise ValueError("model is required")
|
||||
|
||||
if not isinstance(config["model"], dict):
|
||||
@ -52,17 +52,16 @@ class ModelConfigManager:
|
||||
# model.provider
|
||||
provider_entities = model_provider_factory.get_providers()
|
||||
model_provider_names = [provider.provider for provider in provider_entities]
|
||||
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
# model.name
|
||||
if 'name' not in config["model"]:
|
||||
if "name" not in config["model"]:
|
||||
raise ValueError("model.name is required")
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
models = provider_manager.get_configurations(tenant_id).get_models(
|
||||
provider=config["model"]["provider"],
|
||||
model_type=ModelType.LLM
|
||||
provider=config["model"]["provider"], model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if not models:
|
||||
@ -80,12 +79,12 @@ class ModelConfigManager:
|
||||
|
||||
# model.mode
|
||||
if model_mode:
|
||||
config['model']["mode"] = model_mode
|
||||
config["model"]["mode"] = model_mode
|
||||
else:
|
||||
config['model']["mode"] = "completion"
|
||||
config["model"]["mode"] = "completion"
|
||||
|
||||
# model.completion_params
|
||||
if 'completion_params' not in config["model"]:
|
||||
if "completion_params" not in config["model"]:
|
||||
raise ValueError("model.completion_params is required")
|
||||
|
||||
config["model"]["completion_params"] = cls.validate_model_completion_params(
|
||||
@ -101,7 +100,7 @@ class ModelConfigManager:
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
# stop
|
||||
if 'stop' not in cp:
|
||||
if "stop" not in cp:
|
||||
cp["stop"] = []
|
||||
elif not isinstance(cp["stop"], list):
|
||||
raise ValueError("stop in model.completion_params must be of list type")
|
||||
|
||||
@ -14,39 +14,33 @@ class PromptTemplateConfigManager:
|
||||
if not config.get("prompt_type"):
|
||||
raise ValueError("prompt_type is required")
|
||||
|
||||
prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
|
||||
prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
|
||||
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
|
||||
simple_prompt_template = config.get("pre_prompt", "")
|
||||
return PromptTemplateEntity(
|
||||
prompt_type=prompt_type,
|
||||
simple_prompt_template=simple_prompt_template
|
||||
)
|
||||
return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
|
||||
else:
|
||||
advanced_chat_prompt_template = None
|
||||
chat_prompt_config = config.get("chat_prompt_config", {})
|
||||
if chat_prompt_config:
|
||||
chat_prompt_messages = []
|
||||
for message in chat_prompt_config.get("prompt", []):
|
||||
chat_prompt_messages.append({
|
||||
"text": message["text"],
|
||||
"role": PromptMessageRole.value_of(message["role"])
|
||||
})
|
||||
chat_prompt_messages.append(
|
||||
{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
|
||||
)
|
||||
|
||||
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
|
||||
messages=chat_prompt_messages
|
||||
)
|
||||
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
|
||||
|
||||
advanced_completion_prompt_template = None
|
||||
completion_prompt_config = config.get("completion_prompt_config", {})
|
||||
if completion_prompt_config:
|
||||
completion_prompt_template_params = {
|
||||
'prompt': completion_prompt_config['prompt']['text'],
|
||||
"prompt": completion_prompt_config["prompt"]["text"],
|
||||
}
|
||||
|
||||
if 'conversation_histories_role' in completion_prompt_config:
|
||||
completion_prompt_template_params['role_prefix'] = {
|
||||
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
|
||||
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
|
||||
if "conversation_histories_role" in completion_prompt_config:
|
||||
completion_prompt_template_params["role_prefix"] = {
|
||||
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
|
||||
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
|
||||
}
|
||||
|
||||
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
|
||||
@ -56,7 +50,7 @@ class PromptTemplateConfigManager:
|
||||
return PromptTemplateEntity(
|
||||
prompt_type=prompt_type,
|
||||
advanced_chat_prompt_template=advanced_chat_prompt_template,
|
||||
advanced_completion_prompt_template=advanced_completion_prompt_template
|
||||
advanced_completion_prompt_template=advanced_completion_prompt_template,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -72,7 +66,7 @@ class PromptTemplateConfigManager:
|
||||
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
|
||||
|
||||
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
|
||||
if config['prompt_type'] not in prompt_type_vals:
|
||||
if config["prompt_type"] not in prompt_type_vals:
|
||||
raise ValueError(f"prompt_type must be in {prompt_type_vals}")
|
||||
|
||||
# chat_prompt_config
|
||||
@ -89,27 +83,28 @@ class PromptTemplateConfigManager:
|
||||
if not isinstance(config["completion_prompt_config"], dict):
|
||||
raise ValueError("completion_prompt_config must be of object type")
|
||||
|
||||
if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
|
||||
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
|
||||
raise ValueError("chat_prompt_config or completion_prompt_config is required "
|
||||
"when prompt_type is advanced")
|
||||
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
|
||||
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
|
||||
raise ValueError(
|
||||
"chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
|
||||
)
|
||||
|
||||
model_mode_vals = [mode.value for mode in ModelMode]
|
||||
if config['model']["mode"] not in model_mode_vals:
|
||||
if config["model"]["mode"] not in model_mode_vals:
|
||||
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
|
||||
|
||||
if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
|
||||
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
|
||||
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
|
||||
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
|
||||
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
|
||||
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
|
||||
|
||||
if not user_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
|
||||
config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
|
||||
|
||||
if not assistant_prefix:
|
||||
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
|
||||
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
|
||||
|
||||
if config['model']["mode"] == ModelMode.CHAT.value:
|
||||
prompt_list = config['chat_prompt_config']['prompt']
|
||||
if config["model"]["mode"] == ModelMode.CHAT.value:
|
||||
prompt_list = config["chat_prompt_config"]["prompt"]
|
||||
|
||||
if len(prompt_list) > 10:
|
||||
raise ValueError("prompt messages must be less than 10")
|
||||
|
||||
@ -16,51 +16,49 @@ class BasicVariablesConfigManager:
|
||||
variable_entities = []
|
||||
|
||||
# old external_data_tools
|
||||
external_data_tools = config.get('external_data_tools', [])
|
||||
external_data_tools = config.get("external_data_tools", [])
|
||||
for external_data_tool in external_data_tools:
|
||||
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
|
||||
if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
|
||||
continue
|
||||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=external_data_tool['variable'],
|
||||
type=external_data_tool['type'],
|
||||
config=external_data_tool['config']
|
||||
variable=external_data_tool["variable"],
|
||||
type=external_data_tool["type"],
|
||||
config=external_data_tool["config"],
|
||||
)
|
||||
)
|
||||
|
||||
# variables and external_data_tools
|
||||
for variables in config.get('user_input_form', []):
|
||||
for variables in config.get("user_input_form", []):
|
||||
variable_type = list(variables.keys())[0]
|
||||
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
|
||||
variable = variables[variable_type]
|
||||
if 'config' not in variable:
|
||||
if "config" not in variable:
|
||||
continue
|
||||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=variable['variable'],
|
||||
type=variable['type'],
|
||||
config=variable['config']
|
||||
variable=variable["variable"], type=variable["type"], config=variable["config"]
|
||||
)
|
||||
)
|
||||
elif variable_type in [
|
||||
elif variable_type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.SELECT,
|
||||
]:
|
||||
}:
|
||||
variable = variables[variable_type]
|
||||
variable_entities.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get('variable'),
|
||||
description=variable.get('description'),
|
||||
label=variable.get('label'),
|
||||
required=variable.get('required', False),
|
||||
max_length=variable.get('max_length'),
|
||||
options=variable.get('options'),
|
||||
default=variable.get('default'),
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description"),
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options"),
|
||||
default=variable.get("default"),
|
||||
)
|
||||
)
|
||||
|
||||
@ -99,17 +97,17 @@ class BasicVariablesConfigManager:
|
||||
variables = []
|
||||
for item in config["user_input_form"]:
|
||||
key = list(item.keys())[0]
|
||||
if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]:
|
||||
if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
|
||||
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
|
||||
|
||||
form_item = item[key]
|
||||
if 'label' not in form_item:
|
||||
if "label" not in form_item:
|
||||
raise ValueError("label is required in user_input_form")
|
||||
|
||||
if not isinstance(form_item["label"], str):
|
||||
raise ValueError("label in user_input_form must be of string type")
|
||||
|
||||
if 'variable' not in form_item:
|
||||
if "variable" not in form_item:
|
||||
raise ValueError("variable is required in user_input_form")
|
||||
|
||||
if not isinstance(form_item["variable"], str):
|
||||
@ -117,26 +115,24 @@ class BasicVariablesConfigManager:
|
||||
|
||||
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
|
||||
if pattern.match(form_item["variable"]) is None:
|
||||
raise ValueError("variable in user_input_form must be a string, "
|
||||
"and cannot start with a number")
|
||||
raise ValueError("variable in user_input_form must be a string, and cannot start with a number")
|
||||
|
||||
variables.append(form_item["variable"])
|
||||
|
||||
if 'required' not in form_item or not form_item["required"]:
|
||||
if "required" not in form_item or not form_item["required"]:
|
||||
form_item["required"] = False
|
||||
|
||||
if not isinstance(form_item["required"], bool):
|
||||
raise ValueError("required in user_input_form must be of boolean type")
|
||||
|
||||
if key == "select":
|
||||
if 'options' not in form_item or not form_item["options"]:
|
||||
if "options" not in form_item or not form_item["options"]:
|
||||
form_item["options"] = []
|
||||
|
||||
if not isinstance(form_item["options"], list):
|
||||
raise ValueError("options in user_input_form must be a list of strings")
|
||||
|
||||
if "default" in form_item and form_item['default'] \
|
||||
and form_item["default"] not in form_item["options"]:
|
||||
if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
|
||||
raise ValueError("default value in user_input_form must be in the options list")
|
||||
|
||||
return config, ["user_input_form"]
|
||||
@ -168,10 +164,6 @@ class BasicVariablesConfigManager:
|
||||
typ = tool["type"]
|
||||
config = tool["config"]
|
||||
|
||||
ExternalDataToolFactory.validate_config(
|
||||
name=typ,
|
||||
tenant_id=tenant_id,
|
||||
config=config
|
||||
)
|
||||
ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)
|
||||
|
||||
return config, ["external_data_tools"]
|
||||
|
||||
@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel):
|
||||
"""
|
||||
Model Config Entity.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
mode: Optional[str] = None
|
||||
@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel):
|
||||
"""
|
||||
Advanced Chat Message Entity.
|
||||
"""
|
||||
|
||||
text: str
|
||||
role: PromptMessageRole
|
||||
|
||||
@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
|
||||
"""
|
||||
Advanced Chat Prompt Template Entity.
|
||||
"""
|
||||
|
||||
messages: list[AdvancedChatMessageEntity]
|
||||
|
||||
|
||||
@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
|
||||
"""
|
||||
Role Prefix Entity.
|
||||
"""
|
||||
|
||||
user: str
|
||||
assistant: str
|
||||
|
||||
@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel):
|
||||
Prompt Type.
|
||||
'simple', 'advanced'
|
||||
"""
|
||||
SIMPLE = 'simple'
|
||||
ADVANCED = 'advanced'
|
||||
|
||||
SIMPLE = "simple"
|
||||
ADVANCED = "advanced"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'PromptType':
|
||||
def value_of(cls, value: str) -> "PromptType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel):
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid prompt type value {value}')
|
||||
raise ValueError(f"invalid prompt type value {value}")
|
||||
|
||||
prompt_type: PromptType
|
||||
simple_prompt_template: Optional[str] = None
|
||||
@ -87,7 +92,7 @@ class VariableEntityType(str, Enum):
|
||||
SELECT = "select"
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external-data-tool"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
External Data Variable Entity.
|
||||
"""
|
||||
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
||||
Dataset Retrieve Strategy.
|
||||
'single' or 'multiple'
|
||||
"""
|
||||
SINGLE = 'single'
|
||||
MULTIPLE = 'multiple'
|
||||
|
||||
SINGLE = "single"
|
||||
MULTIPLE = "multiple"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'RetrieveStrategy':
|
||||
def value_of(cls, value: str) -> "RetrieveStrategy":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid retrieve strategy value {value}')
|
||||
raise ValueError(f"invalid retrieve strategy value {value}")
|
||||
|
||||
query_variable: Optional[str] = None # Only when app mode is completion
|
||||
|
||||
retrieve_strategy: RetrieveStrategy
|
||||
top_k: Optional[int] = None
|
||||
score_threshold: Optional[float] = .0
|
||||
rerank_mode: Optional[str] = 'reranking_model'
|
||||
score_threshold: Optional[float] = 0.0
|
||||
rerank_mode: Optional[str] = "reranking_model"
|
||||
reranking_model: Optional[dict] = None
|
||||
weights: Optional[dict] = None
|
||||
reranking_enabled: Optional[bool] = True
|
||||
|
||||
|
||||
|
||||
|
||||
class DatasetEntity(BaseModel):
|
||||
"""
|
||||
Dataset Config Entity.
|
||||
"""
|
||||
|
||||
dataset_ids: list[str]
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
|
||||
@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
|
||||
"""
|
||||
Sensitive Word Avoidance Entity.
|
||||
"""
|
||||
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
|
||||
@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel):
|
||||
"""
|
||||
Sensitive Word Avoidance Entity.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
voice: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel):
|
||||
"""
|
||||
Tracing Config Entity.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
tracing_provider: str
|
||||
|
||||
|
||||
|
||||
|
||||
class AppAdditionalFeatures(BaseModel):
|
||||
file_upload: Optional[FileExtraConfig] = None
|
||||
opening_statement: Optional[str] = None
|
||||
@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel):
|
||||
text_to_speech: Optional[TextToSpeechEntity] = None
|
||||
trace_config: Optional[TracingConfigEntity] = None
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""
|
||||
Application Config Entity.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
app_mode: AppMode
|
||||
@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum):
|
||||
"""
|
||||
App Model Config From.
|
||||
"""
|
||||
ARGS = 'args'
|
||||
APP_LATEST_CONFIG = 'app-latest-config'
|
||||
CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
|
||||
|
||||
ARGS = "args"
|
||||
APP_LATEST_CONFIG = "app-latest-config"
|
||||
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
|
||||
|
||||
|
||||
class EasyUIBasedAppConfig(AppConfig):
|
||||
"""
|
||||
Easy UI Based App Config Entity.
|
||||
"""
|
||||
|
||||
app_model_config_from: EasyUIBasedAppModelConfigFrom
|
||||
app_model_config_id: str
|
||||
app_model_config_dict: dict
|
||||
@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig):
|
||||
"""
|
||||
Workflow UI Based App Config Entity.
|
||||
"""
|
||||
|
||||
workflow_id: str
|
||||
|
||||
@ -13,21 +13,19 @@ class FileUploadConfigManager:
|
||||
:param config: model config args
|
||||
:param is_vision: if True, the feature is vision feature
|
||||
"""
|
||||
file_upload_dict = config.get('file_upload')
|
||||
file_upload_dict = config.get("file_upload")
|
||||
if file_upload_dict:
|
||||
if file_upload_dict.get('image'):
|
||||
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
|
||||
if file_upload_dict.get("image"):
|
||||
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
|
||||
image_config = {
|
||||
'number_limits': file_upload_dict['image']['number_limits'],
|
||||
'transfer_methods': file_upload_dict['image']['transfer_methods']
|
||||
"number_limits": file_upload_dict["image"]["number_limits"],
|
||||
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
|
||||
}
|
||||
|
||||
if is_vision:
|
||||
image_config['detail'] = file_upload_dict['image']['detail']
|
||||
image_config["detail"] = file_upload_dict["image"]["detail"]
|
||||
|
||||
return FileExtraConfig(
|
||||
image_config=image_config
|
||||
)
|
||||
return FileExtraConfig(image_config=image_config)
|
||||
|
||||
return None
|
||||
|
||||
@ -49,21 +47,21 @@ class FileUploadConfigManager:
|
||||
if not config["file_upload"].get("image"):
|
||||
config["file_upload"]["image"] = {"enabled": False}
|
||||
|
||||
if config['file_upload']['image']['enabled']:
|
||||
number_limits = config['file_upload']['image']['number_limits']
|
||||
if config["file_upload"]["image"]["enabled"]:
|
||||
number_limits = config["file_upload"]["image"]["number_limits"]
|
||||
if number_limits < 1 or number_limits > 6:
|
||||
raise ValueError("number_limits must be in [1, 6]")
|
||||
|
||||
if is_vision:
|
||||
detail = config['file_upload']['image']['detail']
|
||||
if detail not in ['high', 'low']:
|
||||
detail = config["file_upload"]["image"]["detail"]
|
||||
if detail not in {"high", "low"}:
|
||||
raise ValueError("detail must be in ['high', 'low']")
|
||||
|
||||
transfer_methods = config['file_upload']['image']['transfer_methods']
|
||||
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
|
||||
if not isinstance(transfer_methods, list):
|
||||
raise ValueError("transfer_methods must be of list type")
|
||||
for method in transfer_methods:
|
||||
if method not in ['remote_url', 'local_file']:
|
||||
if method not in {"remote_url", "local_file"}:
|
||||
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
|
||||
|
||||
return config, ["file_upload"]
|
||||
|
||||
@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
|
||||
:param config: model config args
|
||||
"""
|
||||
more_like_this = False
|
||||
more_like_this_dict = config.get('more_like_this')
|
||||
more_like_this_dict = config.get("more_like_this")
|
||||
if more_like_this_dict:
|
||||
if more_like_this_dict.get('enabled'):
|
||||
if more_like_this_dict.get("enabled"):
|
||||
more_like_this = True
|
||||
|
||||
return more_like_this
|
||||
@ -22,9 +22,7 @@ class MoreLikeThisConfigManager:
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("more_like_this"):
|
||||
config["more_like_this"] = {
|
||||
"enabled": False
|
||||
}
|
||||
config["more_like_this"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["more_like_this"], dict):
|
||||
raise ValueError("more_like_this must be of dict type")
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
class OpeningStatementConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> tuple[str, list]:
|
||||
@ -9,10 +7,10 @@ class OpeningStatementConfigManager:
|
||||
:param config: model config args
|
||||
"""
|
||||
# opening statement
|
||||
opening_statement = config.get('opening_statement')
|
||||
opening_statement = config.get("opening_statement")
|
||||
|
||||
# suggested questions
|
||||
suggested_questions_list = config.get('suggested_questions')
|
||||
suggested_questions_list = config.get("suggested_questions")
|
||||
|
||||
return opening_statement, suggested_questions_list
|
||||
|
||||
|
||||
@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> bool:
|
||||
show_retrieve_source = False
|
||||
retriever_resource_dict = config.get('retriever_resource')
|
||||
retriever_resource_dict = config.get("retriever_resource")
|
||||
if retriever_resource_dict:
|
||||
if retriever_resource_dict.get('enabled'):
|
||||
if retriever_resource_dict.get("enabled"):
|
||||
show_retrieve_source = True
|
||||
|
||||
return show_retrieve_source
|
||||
@ -17,9 +17,7 @@ class RetrievalResourceConfigManager:
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("retriever_resource"):
|
||||
config["retriever_resource"] = {
|
||||
"enabled": False
|
||||
}
|
||||
config["retriever_resource"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["retriever_resource"], dict):
|
||||
raise ValueError("retriever_resource must be of dict type")
|
||||
|
||||
@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
|
||||
:param config: model config args
|
||||
"""
|
||||
speech_to_text = False
|
||||
speech_to_text_dict = config.get('speech_to_text')
|
||||
speech_to_text_dict = config.get("speech_to_text")
|
||||
if speech_to_text_dict:
|
||||
if speech_to_text_dict.get('enabled'):
|
||||
if speech_to_text_dict.get("enabled"):
|
||||
speech_to_text = True
|
||||
|
||||
return speech_to_text
|
||||
@ -22,9 +22,7 @@ class SpeechToTextConfigManager:
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("speech_to_text"):
|
||||
config["speech_to_text"] = {
|
||||
"enabled": False
|
||||
}
|
||||
config["speech_to_text"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["speech_to_text"], dict):
|
||||
raise ValueError("speech_to_text must be of dict type")
|
||||
|
||||
@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
:param config: model config args
|
||||
"""
|
||||
suggested_questions_after_answer = False
|
||||
suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
|
||||
suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
|
||||
if suggested_questions_after_answer_dict:
|
||||
if suggested_questions_after_answer_dict.get('enabled'):
|
||||
if suggested_questions_after_answer_dict.get("enabled"):
|
||||
suggested_questions_after_answer = True
|
||||
|
||||
return suggested_questions_after_answer
|
||||
@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("suggested_questions_after_answer"):
|
||||
config["suggested_questions_after_answer"] = {
|
||||
"enabled": False
|
||||
}
|
||||
config["suggested_questions_after_answer"] = {"enabled": False}
|
||||
|
||||
if not isinstance(config["suggested_questions_after_answer"], dict):
|
||||
raise ValueError("suggested_questions_after_answer must be of dict type")
|
||||
|
||||
if "enabled" not in config["suggested_questions_after_answer"] or not \
|
||||
config["suggested_questions_after_answer"]["enabled"]:
|
||||
if (
|
||||
"enabled" not in config["suggested_questions_after_answer"]
|
||||
or not config["suggested_questions_after_answer"]["enabled"]
|
||||
):
|
||||
config["suggested_questions_after_answer"]["enabled"] = False
|
||||
|
||||
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
|
||||
|
||||
@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
|
||||
:param config: model config args
|
||||
"""
|
||||
text_to_speech = None
|
||||
text_to_speech_dict = config.get('text_to_speech')
|
||||
text_to_speech_dict = config.get("text_to_speech")
|
||||
if text_to_speech_dict:
|
||||
if text_to_speech_dict.get('enabled'):
|
||||
if text_to_speech_dict.get("enabled"):
|
||||
text_to_speech = TextToSpeechEntity(
|
||||
enabled=text_to_speech_dict.get('enabled'),
|
||||
voice=text_to_speech_dict.get('voice'),
|
||||
language=text_to_speech_dict.get('language'),
|
||||
enabled=text_to_speech_dict.get("enabled"),
|
||||
voice=text_to_speech_dict.get("voice"),
|
||||
language=text_to_speech_dict.get("language"),
|
||||
)
|
||||
|
||||
return text_to_speech
|
||||
@ -29,11 +29,7 @@ class TextToSpeechConfigManager:
|
||||
:param config: app model config args
|
||||
"""
|
||||
if not config.get("text_to_speech"):
|
||||
config["text_to_speech"] = {
|
||||
"enabled": False,
|
||||
"voice": "",
|
||||
"language": ""
|
||||
}
|
||||
config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
|
||||
|
||||
if not isinstance(config["text_to_speech"], dict):
|
||||
raise ValueError("text_to_speech must be of dict type")
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
|
||||
"""
|
||||
Advanced Chatbot App Config Entity.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(cls, app_model: App,
|
||||
workflow: Workflow) -> AdvancedChatAppConfig:
|
||||
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
|
||||
features_dict = workflow.features_dict
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
app_id=app_model.id,
|
||||
app_mode=app_mode,
|
||||
workflow_id=workflow.id,
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
|
||||
config=features_dict
|
||||
),
|
||||
variables=WorkflowVariablesConfigManager.convert(
|
||||
workflow=workflow
|
||||
),
|
||||
additional_features=cls.convert_features(features_dict, app_mode)
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
|
||||
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
|
||||
additional_features=cls.convert_features(features_dict, app_mode),
|
||||
)
|
||||
|
||||
return app_config
|
||||
@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
|
||||
# file upload validation
|
||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
|
||||
config=config,
|
||||
is_vision=False
|
||||
config=config, is_vision=False
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
|
||||
# suggested_questions_after_answer
|
||||
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
|
||||
config)
|
||||
config
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# speech_to_text
|
||||
@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
|
||||
# moderation validation
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
only_structure_validate=only_structure_validate
|
||||
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
|
||||
|
||||
@ -4,12 +4,10 @@ import os
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import contexts
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
@ -17,36 +15,54 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import ConversationVariable, Workflow
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
):
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -57,44 +73,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
if not args.get('query'):
|
||||
raise ValueError('query is required')
|
||||
if not args.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
query = args['query']
|
||||
query = args["query"]
|
||||
if not isinstance(query, str):
|
||||
raise ValueError('query must be a string')
|
||||
raise ValueError("query must be a string")
|
||||
|
||||
query = query.replace('\x00', '')
|
||||
inputs = args['inputs']
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": args.get('auto_generate_name', False)
|
||||
}
|
||||
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
conversation_id = args.get('conversation_id')
|
||||
conversation_id = args.get("conversation_id")
|
||||
if conversation_id:
|
||||
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
conversation = self._get_conversation_by_user(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
|
||||
# parse files
|
||||
files = args['files'] if args.get('files') else []
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
# get tracing instance
|
||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
||||
@ -116,7 +125,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
|
||||
@ -126,15 +135,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from=invoke_from,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def single_iteration_generate(self, app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account,
|
||||
args: dict,
|
||||
stream: bool = True):
|
||||
def single_iteration_generate(
|
||||
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -146,43 +152,29 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param stream: is stream
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError('node_id is required')
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
conversation_id = args.get('conversation_id')
|
||||
if conversation_id:
|
||||
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
conversation_id=None,
|
||||
inputs={},
|
||||
query='',
|
||||
query="",
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras=extras,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id,
|
||||
inputs=args['inputs']
|
||||
)
|
||||
node_id=node_id, inputs=args["inputs"]
|
||||
),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
|
||||
@ -191,32 +183,42 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
stream=stream
|
||||
conversation=None,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def _generate(self, *,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Conversation | None = None,
|
||||
stream: bool = True):
|
||||
def _generate(
|
||||
self,
|
||||
*,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = True,
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param invoke_from: invoke from source
|
||||
:param application_generate_entity: application generate entity
|
||||
:param conversation: conversation
|
||||
:param stream: is stream
|
||||
"""
|
||||
is_first_conversation = False
|
||||
if not conversation:
|
||||
is_first_conversation = True
|
||||
|
||||
# init generate records
|
||||
(
|
||||
conversation,
|
||||
message
|
||||
) = self._init_generate_records(application_generate_entity, conversation)
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
if is_first_conversation:
|
||||
# update conversation features
|
||||
conversation.override_model_configs = workflow.features
|
||||
db.session.commit()
|
||||
# db.session.refresh(conversation)
|
||||
db.session.refresh(conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
@ -225,73 +227,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
# Create conversation variables if they don't exist.
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
# Convert database entities to variables.
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
session.commit()
|
||||
|
||||
# Increment dialogue count.
|
||||
conversation.dialogue_count += 1
|
||||
|
||||
conversation_id = conversation.id
|
||||
conversation_dialogue_count = conversation.dialogue_count
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.QUERY: query,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation_id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
contexts.workflow_variable_pool.set(variable_pool)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'message_id': message.id,
|
||||
'context': contextvars.copy_context(),
|
||||
})
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"queue_manager": queue_manager,
|
||||
"conversation_id": conversation.id,
|
||||
"message_id": message.id,
|
||||
"context": contextvars.copy_context(),
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
@ -306,16 +256,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(
|
||||
response=response,
|
||||
invoke_from=invoke_from
|
||||
)
|
||||
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message_id: str,
|
||||
context: contextvars.Context) -> None:
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
context: contextvars.Context,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
@ -329,40 +280,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
var.set(val)
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
runner = AdvancedChatAppRunner()
|
||||
if application_generate_entity.single_iteration_run:
|
||||
single_iteration_run = application_generate_entity.single_iteration_run
|
||||
runner.single_iteration_run(
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=application_generate_entity.app_config.workflow_id,
|
||||
queue_manager=queue_manager,
|
||||
inputs=single_iteration_run.inputs,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_id=application_generate_entity.user_id
|
||||
)
|
||||
else:
|
||||
# get message
|
||||
message = self._get_message(message_id)
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError('Incorrect API key provided'),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
if os.environ.get("DEBUG", "false").lower() == 'true':
|
||||
if os.environ.get("DEBUG", "false").lower() == "true":
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
@ -408,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
@ -21,14 +21,11 @@ class AudioTrunk:
|
||||
self.status = status
|
||||
|
||||
|
||||
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
if not text_content or text_content.isspace():
|
||||
return
|
||||
return model_instance.invoke_tts(
|
||||
content_text=text_content.strip(),
|
||||
user="responding_tts",
|
||||
tenant_id=tenant_id,
|
||||
voice=voice
|
||||
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
|
||||
)
|
||||
|
||||
|
||||
@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue):
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(e)
|
||||
break
|
||||
audio_queue.put(AudioTrunk("finish", b''))
|
||||
audio_queue.put(AudioTrunk("finish", b""))
|
||||
|
||||
|
||||
class AppGeneratorTTSPublisher:
|
||||
|
||||
def __init__(self, tenant_id: str, voice: str):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.tenant_id = tenant_id
|
||||
self.msg_text = ''
|
||||
self.msg_text = ""
|
||||
self._audio_queue = queue.Queue()
|
||||
self._msg_queue = queue.Queue()
|
||||
self.match = re.compile(r'[。.!?]')
|
||||
self.match = re.compile(r"[。.!?]")
|
||||
self.model_manager = ModelManager()
|
||||
self.model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.TTS
|
||||
tenant_id=self.tenant_id, model_type=ModelType.TTS
|
||||
)
|
||||
self.voices = self.model_instance.get_tts_voices()
|
||||
values = [voice.get('value') for voice in self.voices]
|
||||
values = [voice.get("value") for voice in self.voices]
|
||||
self.voice = voice
|
||||
if not voice or voice not in values:
|
||||
self.voice = self.voices[0].get('value')
|
||||
self.voice = self.voices[0].get("value")
|
||||
self.MAX_SENTENCE = 2
|
||||
self._last_audio_event = None
|
||||
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
||||
@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher:
|
||||
message = self._msg_queue.get()
|
||||
if message is None:
|
||||
if self.msg_text and len(self.msg_text.strip()) > 0:
|
||||
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
|
||||
self.model_instance, self.tenant_id, self.voice)
|
||||
futures_result = self.executor.submit(
|
||||
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
break
|
||||
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
|
||||
@ -94,28 +90,27 @@ class AppGeneratorTTSPublisher:
|
||||
elif isinstance(message.event, QueueTextChunkEvent):
|
||||
self.msg_text += message.event.text
|
||||
elif isinstance(message.event, QueueNodeSucceededEvent):
|
||||
self.msg_text += message.event.outputs.get('output', '')
|
||||
self.msg_text += message.event.outputs.get("output", "")
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
|
||||
self.MAX_SENTENCE += 1
|
||||
text_content = ''.join(sentence_arr)
|
||||
futures_result = self.executor.submit(_invoiceTTS, text_content,
|
||||
self.model_instance,
|
||||
self.tenant_id,
|
||||
self.voice)
|
||||
text_content = "".join(sentence_arr)
|
||||
futures_result = self.executor.submit(
|
||||
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
if text_tmp:
|
||||
self.msg_text = text_tmp
|
||||
else:
|
||||
self.msg_text = ''
|
||||
self.msg_text = ""
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(e)
|
||||
break
|
||||
future_queue.put(None)
|
||||
|
||||
def checkAndGetAudio(self) -> AudioTrunk | None:
|
||||
def check_and_get_audio(self) -> AudioTrunk | None:
|
||||
try:
|
||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||
if self.executor:
|
||||
|
||||
@ -1,145 +1,197 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||
from core.moderation.base import ModerationException
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.moderation.base import ModerationError
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models import App, Message, Workflow
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import ConversationVariable, WorkflowType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancedChatAppRunner(AppRunner):
|
||||
class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
AdvancedChat Application Runner
|
||||
"""
|
||||
|
||||
def run(
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
"""
|
||||
super().__init__(queue_manager)
|
||||
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.conversation = conversation
|
||||
self.message = message
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run application
|
||||
:return:
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError('App not found')
|
||||
raise ValueError("App not found")
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
user_id = None
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
queue_manager=queue_manager,
|
||||
app_record=app_record,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
):
|
||||
return
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=app_record,
|
||||
message=message,
|
||||
query=query,
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
):
|
||||
return
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
query = self.application_generate_entity.query
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
app_record=app_record,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=self.message.id,
|
||||
):
|
||||
return
|
||||
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=app_record,
|
||||
message=self.message,
|
||||
query=query,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
):
|
||||
return
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == self.conversation.app_id,
|
||||
ConversationVariable.conversation_id == self.conversation.id,
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
# Create conversation variables if they don't exist.
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
# Convert database entities to variables.
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
session.commit()
|
||||
|
||||
# Increment dialogue count.
|
||||
self.conversation.dialogue_count += 1
|
||||
|
||||
conversation_dialogue_count = self.conversation.dialogue_count
|
||||
db.session.commit()
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.QUERY: query,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
}
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=workflow.graph_dict)
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = [
|
||||
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
|
||||
]
|
||||
|
||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow=workflow,
|
||||
user_id=application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
graph=graph,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=(
|
||||
UserFrom.ACCOUNT
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else UserFrom.END_USER
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
)
|
||||
|
||||
def single_iteration_run(
|
||||
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError('App not found')
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.single_step_run_iteration_workflow_node(
|
||||
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
def handle_input_moderation(
|
||||
self,
|
||||
queue_manager: AppQueueManager,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
@ -148,7 +200,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
) -> bool:
|
||||
"""
|
||||
Handle input moderation
|
||||
:param queue_manager: application queue manager
|
||||
:param app_record: app record
|
||||
:param app_generate_entity: application generate entity
|
||||
:param inputs: inputs
|
||||
@ -166,31 +217,20 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
self._stream_output(
|
||||
queue_manager=queue_manager,
|
||||
text=str(e),
|
||||
stream=app_generate_entity.stream,
|
||||
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
|
||||
)
|
||||
except ModerationError as e:
|
||||
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def handle_annotation_reply(
|
||||
self,
|
||||
app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
queue_manager: AppQueueManager,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
|
||||
) -> bool:
|
||||
"""
|
||||
Handle annotation reply
|
||||
:param app_record: app record
|
||||
:param message: message
|
||||
:param query: query
|
||||
:param queue_manager: application queue manager
|
||||
:param app_generate_entity: application generate entity
|
||||
"""
|
||||
# annotation reply
|
||||
@ -203,37 +243,21 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
)
|
||||
|
||||
if annotation_reply:
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))
|
||||
|
||||
self._stream_output(
|
||||
queue_manager=queue_manager,
|
||||
text=annotation_reply.content,
|
||||
stream=app_generate_entity.stream,
|
||||
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
|
||||
self._complete_with_stream_output(
|
||||
text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _stream_output(
|
||||
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
|
||||
) -> None:
|
||||
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param queue_manager: application queue manager
|
||||
:param text: text
|
||||
:param stream: stream
|
||||
:return:
|
||||
"""
|
||||
if stream:
|
||||
index = 0
|
||||
for token in text:
|
||||
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
|
||||
self._publish_event(QueueTextChunkEvent(text=text))
|
||||
|
||||
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
|
||||
self._publish_event(QueueStopEvent(stopped_by=stopped_by))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user