Compare commits


1 commit

Author: Yi
SHA1: 2474dbdff0
Message: fix the tooltip for the knowledge base's firecrawl max depth attribute
Date: 2024-08-28 17:09:30 +08:00
1632 changed files with 32913 additions and 56546 deletions

View File

@@ -20,7 +20,7 @@ jobs:
       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v45
+        uses: tj-actions/changed-files@v44
         with:
           files: api/**
@@ -66,7 +66,7 @@ jobs:
       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v45
+        uses: tj-actions/changed-files@v44
         with:
           files: web/**
@@ -97,7 +97,7 @@ jobs:
       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v45
+        uses: tj-actions/changed-files@v44
         with:
           files: |
             **.sh
@@ -107,7 +107,7 @@ jobs:
             dev/**
       - name: Super-linter
-        uses: super-linter/super-linter/slim@v7
+        uses: super-linter/super-linter/slim@v6
         if: steps.changed-files.outputs.any_changed == 'true'
         env:
           BASH_SEVERITY: warning

View File

@@ -1,54 +0,0 @@
-name: Check i18n Files and Create PR
-
-on:
-  pull_request:
-    types: [closed]
-    branches: [main]
-
-jobs:
-  check-and-update:
-    if: github.event.pull_request.merged == true
-    runs-on: ubuntu-latest
-    defaults:
-      run:
-        working-directory: web
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          fetch-depth: 2 # last 2 commits
-
-      - name: Check for file changes in i18n/en-US
-        id: check_files
-        run: |
-          recent_commit_sha=$(git rev-parse HEAD)
-          second_recent_commit_sha=$(git rev-parse HEAD~1)
-          changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
-          echo "Changed files: $changed_files"
-          if [ -n "$changed_files" ]; then
-            echo "FILES_CHANGED=true" >> $GITHUB_ENV
-          else
-            echo "FILES_CHANGED=false" >> $GITHUB_ENV
-          fi
-
-      - name: Set up Node.js
-        if: env.FILES_CHANGED == 'true'
-        uses: actions/setup-node@v2
-        with:
-          node-version: 'lts/*'
-
-      - name: Install dependencies
-        if: env.FILES_CHANGED == 'true'
-        run: yarn install --frozen-lockfile
-
-      - name: Run npm script
-        if: env.FILES_CHANGED == 'true'
-        run: npm run auto-gen-i18n
-
-      - name: Create Pull Request
-        if: env.FILES_CHANGED == 'true'
-        uses: peter-evans/create-pull-request@v6
-        with:
-          commit-message: Update i18n files based on en-US changes
-          title: 'chore: translate i18n files'
-          body: This PR was automatically created to update i18n files based on changes in en-US locale.
-          branch: chore/automated-i18n-updates

.gitignore (vendored): 9 changes
View File

@@ -153,9 +153,6 @@ docker-legacy/volumes/etcd/*
 docker-legacy/volumes/minio/*
 docker-legacy/volumes/milvus/*
 docker-legacy/volumes/chroma/*
-docker-legacy/volumes/opensearch/data/*
-docker-legacy/volumes/pgvectors/data/*
-docker-legacy/volumes/pgvector/data/*
 docker/volumes/app/storage/*
 docker/volumes/certbot/*
@@ -167,12 +164,6 @@ docker/volumes/etcd/*
 docker/volumes/minio/*
 docker/volumes/milvus/*
 docker/volumes/chroma/*
-docker/volumes/opensearch/data/*
-docker/volumes/myscale/data/*
-docker/volumes/myscale/log/*
-docker/volumes/unstructured/*
-docker/volumes/pgvector/data/*
-docker/volumes/pgvecto_rs/data/*
 docker/nginx/conf.d/default.conf
 docker/middleware.env

View File

@@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
 ## Before you jump in

-[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
+[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:

 ### Feature requests:

View File

@@ -8,7 +8,7 @@
 ## Before you start

-[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [create](https://github.com/langgenius/dify/issues/new/choose) a new one. We divide issues into two categories:
+[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [create](https://github.com/langgenius/dify/issues/new/choose) a new one. We divide issues into two categories:

 ### Feature requests:
@@ -36,7 +36,7 @@
 | Features marked high priority by a team member | High priority |
 | Popular feature requests from the [community feedback board](https://github.com/langgenius/dify/discussions/categories/feedbacks) | Medium priority |
 | Non-core features and minor improvements | Low priority |
 | Valuable but not urgent | Future feature |

 ### Anything else (e.g. bug reports, performance optimization, typo fixes):
 * Start coding right away.
@@ -138,7 +138,7 @@ Dify's backend is written in Python, using [Flask](https://flask.palletsproject
 ├── models                // describes data models and the shapes of API responses
 ├── public                // meta resources such as favicon
 ├── service               // defines the shapes of API operations
 ├── test
 ├── types                 // descriptions of function parameters and return values
 └── utils                 // shared utility functions
 ```

View File

@@ -10,7 +10,7 @@ So you want to contribute to Dify? That's
 ## Before diving in

-[Find an existing Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open), or [create a new Issue](https://github.com/langgenius/dify/issues/new/choose). We categorize Issues into two types.
+[Find an existing Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed), or [create a new Issue](https://github.com/langgenius/dify/issues/new/choose). We categorize Issues into two types.

 ### Feature requests

View File

@@ -8,7 +8,7 @@ Regarding licensing, please take a moment to read our short [
 ## Before you begin

-[Search](https://github.com/langgenius/dify/issues?q=is:issue+is:open) for an existing issue, or [create](https://github.com/langgenius/dify/issues/new/choose) a new one. We classify issues into 2 types:
+[Search](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) for an existing issue, or [create](https://github.com/langgenius/dify/issues/new/choose) a new one. We classify issues into 2 types:

 ### Feature requests:

View File

@@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con
 1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:

-a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
+a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.

 - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.

 b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.

View File

@@ -39,7 +39,7 @@ DB_DATABASE=dify

 # Storage configuration
 # use for store upload files, private keys...
-# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos
+# storage type: local, s3, azure-blob, google-storage
 STORAGE_TYPE=local
 STORAGE_LOCAL_PATH=storage
 S3_USE_AWS_MANAGED_IAM=false
@@ -60,8 +60,7 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
 ALIYUN_OSS_ENDPOINT=your-endpoint
 ALIYUN_OSS_AUTH_VERSION=v1
 ALIYUN_OSS_REGION=your-region
-# Don't start with '/'. OSS doesn't support leading slash in object names.
-ALIYUN_OSS_PATH=your-path

 # Google Storage configuration
 GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
 GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
@@ -73,12 +72,6 @@ TENCENT_COS_SECRET_ID=your-secret-id
 TENCENT_COS_REGION=your-region
 TENCENT_COS_SCHEME=your-scheme

-# Huawei OBS Storage Configuration
-HUAWEI_OBS_BUCKET_NAME=your-bucket-name
-HUAWEI_OBS_SECRET_KEY=your-secret-key
-HUAWEI_OBS_ACCESS_KEY=your-access-key
-HUAWEI_OBS_SERVER=your-server-url

 # OCI Storage configuration
 OCI_ENDPOINT=your-endpoint
 OCI_BUCKET_NAME=your-bucket-name
@@ -86,13 +79,6 @@ OCI_ACCESS_KEY=your-access-key
 OCI_SECRET_KEY=your-secret-key
 OCI_REGION=your-region

-# Volcengine tos Storage configuration
-VOLCENGINE_TOS_ENDPOINT=your-endpoint
-VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
-VOLCENGINE_TOS_ACCESS_KEY=your-access-key
-VOLCENGINE_TOS_SECRET_KEY=your-secret-key
-VOLCENGINE_TOS_REGION=your-region

 # CORS configuration
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@@ -114,10 +100,11 @@ QDRANT_GRPC_ENABLED=false
 QDRANT_GRPC_PORT=6334

 # Milvus configuration
-MILVUS_URI=http://127.0.0.1:19530
-MILVUS_TOKEN=
+MILVUS_HOST=127.0.0.1
+MILVUS_PORT=19530
 MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
+MILVUS_SECURE=false

 # MyScale configuration
 MYSCALE_HOST=127.0.0.1

View File

@@ -55,7 +55,7 @@ RUN apt-get update \
     && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
     && apt-get update \
     # For Security
-    && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
+    && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
    && apt-get autoremove -y \
    && rm -rf /var/lib/apt/lists/*

View File

@@ -164,7 +164,7 @@ def initialize_extensions(app):
     @login_manager.request_loader
     def load_user_from_request(request_from_flask_login):
         """Load user based on the request."""
-        if request.blueprint not in {"console", "inner_api"}:
+        if request.blueprint not in ["console", "inner_api"]:
             return None
         # Check if the user_id contains a dot, indicating the old format
         auth_header = request.headers.get("Authorization", "")
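The hunk above swaps a set literal for a list in the blueprint membership test. `in` behaves identically for both container types; the set form gives average O(1) lookups and signals a fixed collection of allowed values. A minimal standalone sketch of the guard (names are illustrative, not Dify's actual code):

```python
# A set literal gives average O(1) membership checks; a list scans
# every element. For two entries the difference is negligible, but
# the set also documents intent: a fixed group of allowed values.
ALLOWED_BLUEPRINTS = {"console", "inner_api"}

def should_load_user(blueprint: str) -> bool:
    # Mirrors the early `return None` guard in the diff: only requests
    # routed to an allowed blueprint proceed to user loading.
    return blueprint in ALLOWED_BLUEPRINTS

print(should_load_user("console"))  # True
print(should_load_user("web"))      # False
```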

View File

@@ -104,7 +104,7 @@ def reset_email(email, new_email, email_confirm):
 )
 @click.confirmation_option(
     prompt=click.style(
-        "Are you sure you want to reset encrypt key pair? this operation cannot be rolled back!", fg="red"
+        "Are you sure you want to reset encrypt key pair?" " this operation cannot be rolled back!", fg="red"
     )
 )
 def reset_encrypt_key_pair():
@@ -131,7 +131,7 @@ def reset_encrypt_key_pair():
         click.echo(
             click.style(
-                "Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id),
+                "Congratulations! " "the asymmetric key pair of workspace {} has been reset.".format(tenant.id),
                 fg="green",
             )
         )
@@ -140,9 +140,9 @@ def reset_encrypt_key_pair():
 @click.command("vdb-migrate", help="migrate vector db.")
 @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.")
 def vdb_migrate(scope: str):
-    if scope in {"knowledge", "all"}:
+    if scope in ["knowledge", "all"]:
         migrate_knowledge_vector_database()
-    if scope in {"annotation", "all"}:
+    if scope in ["annotation", "all"]:
         migrate_annotation_vector_database()
@@ -275,7 +275,8 @@ def migrate_knowledge_vector_database():
     for dataset in datasets:
         total_count = total_count + 1
         click.echo(
-            f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
+            f"Processing the {total_count} dataset {dataset.id}. "
+            + f"{create_count} created, {skipped_count} skipped."
         )
         try:
             click.echo("Create dataset vdb index: {}".format(dataset.id))
@@ -410,8 +411,7 @@ def migrate_knowledge_vector_database():
         try:
             click.echo(
                 click.style(
-                    f"Start to created vector index with {len(documents)} documents of {segments_count}"
-                    f" segments for dataset {dataset.id}.",
+                    f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
                     fg="green",
                 )
             )
@@ -559,9 +559,8 @@ def add_qdrant_doc_id_index(field: str):
 @click.command("create-tenant", help="Create account and tenant.")
 @click.option("--email", prompt=True, help="The email address of the tenant account.")
-@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
 @click.option("--language", prompt=True, help="Account language, default: en-US.")
-def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
+def create_tenant(email: str, language: Optional[str] = None):
     """
     Create tenant account
     """
@@ -581,19 +580,17 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
     if language not in languages:
         language = "en-US"

-    name = name.strip()
-
     # generate random password
     new_password = secrets.token_urlsafe(16)

     # register account
     account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)

-    TenantService.create_owner_tenant_if_not_exist(account, name)
+    TenantService.create_owner_tenant_if_not_exist(account)

     click.echo(
         click.style(
-            "Congratulations! Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
+            "Congratulations! Account and tenant created.\n" "Account: {}\nPassword: {}".format(email, new_password),
             fg="green",
         )
     )
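Several hunks in this file trade one long message string for adjacent string literals. Python concatenates adjacent literals at compile time, so both spellings produce the same runtime string; the split form only changes source layout. A quick demonstration:

```python
# Adjacent string literals are joined at compile time, so the split
# spelling from the diff equals the single-literal spelling.
single = "Congratulations! Account and tenant created."
split = "Congratulations! " "Account and tenant created."
assert single == split

# f-strings can likewise be split and joined with an explicit runtime `+`,
# as in the migrate_knowledge_vector_database hunk above.
total_count, create_count = 3, 2
msg = f"Processing the {total_count} dataset. " + f"{create_count} created."
print(msg)  # Processing the 3 dataset. 2 created.
```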

View File

@@ -1,4 +1,4 @@
-from typing import Annotated, Optional
+from typing import Optional

 from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
 from pydantic_settings import BaseSettings
@@ -46,7 +46,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
     """

     CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
-        description="endpoint URL of code execution service",
+        description="endpoint URL of code execution servcie",
         default="http://sandbox:8194",
     )
@@ -129,12 +129,12 @@ class EndpointConfig(BaseSettings):
     )

     SERVICE_API_URL: str = Field(
-        description="Service API Url prefix. used to display Service API Base Url to the front-end.",
+        description="Service API Url prefix." "used to display Service API Base Url to the front-end.",
         default="",
     )

     APP_WEB_URL: str = Field(
-        description="WebApp Url prefix. used to display WebAPP API Base Url to the front-end.",
+        description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.",
         default="",
     )
@@ -217,17 +217,20 @@ class HttpConfig(BaseSettings):
     def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
         return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")

-    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
-        PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
-    ] = 10
+    HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field(
+        description="",
+        default=300,
+    )

-    HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
-        PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
-    ] = 60
+    HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field(
+        description="",
+        default=600,
+    )

-    HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
-        PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
-    ] = 20
+    HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field(
+        description="",
+        default=600,
+    )

     HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
         description="",
@@ -272,7 +275,7 @@ class LoggingConfig(BaseSettings):
     """

     LOG_LEVEL: str = Field(
-        description="Log output level, default to INFO. It is recommended to set it to ERROR for production.",
+        description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.",
         default="INFO",
     )
@@ -415,7 +418,7 @@ class MailConfig(BaseSettings):
     """

     MAIL_TYPE: Optional[str] = Field(
-        description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
+        description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
         default=None,
     )
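The timeout hunks above replace `Annotated[PositiveInt, Field(ge=10, ...)] = 10` declarations with plain `NonNegativeInt` fields, dropping the lower bounds. A rough stdlib-only sketch of what such a `ge=10` constraint enforces (the class and field names here are illustrative, not Dify's):

```python
from dataclasses import dataclass

@dataclass
class HttpTimeoutConfig:
    # Default mirrors the `= 10` after the Annotated type in the diff.
    connect_timeout: int = 10

    def __post_init__(self) -> None:
        # Mirrors the lower bound that pydantic's Field(ge=10) would apply.
        if self.connect_timeout < 10:
            raise ValueError("connect_timeout must be >= 10")

print(HttpTimeoutConfig().connect_timeout)  # 10
try:
    HttpTimeoutConfig(connect_timeout=5)
except ValueError:
    print("rejected")  # values below the bound fail validation
```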

View File

@@ -1,7 +1,7 @@
 from typing import Any, Optional
 from urllib.parse import quote_plus

-from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
+from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
 from pydantic_settings import BaseSettings

 from configs.middleware.cache.redis_config import RedisConfig
@@ -9,10 +9,8 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag
 from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
 from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
 from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
-from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
 from configs.middleware.storage.oci_storage_config import OCIStorageConfig
 from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
-from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
 from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
 from configs.middleware.vdb.chroma_config import ChromaConfig
 from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
@@ -159,21 +157,6 @@ class CeleryConfig(DatabaseConfig):
         default=None,
     )

-    CELERY_USE_SENTINEL: Optional[bool] = Field(
-        description="Whether to use Redis Sentinel mode",
-        default=False,
-    )
-
-    CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
-        description="Redis Sentinel master name",
-        default=None,
-    )
-
-    CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
-        description="Redis Sentinel socket timeout",
-        default=0.1,
-    )
-
     @computed_field
     @property
     def CELERY_RESULT_BACKEND(self) -> str | None:
@@ -201,8 +184,6 @@ class MiddlewareConfig(
     AzureBlobStorageConfig,
     GoogleCloudStorageConfig,
     TencentCloudCOSStorageConfig,
-    HuaweiCloudOBSStorageConfig,
-    VolcengineTOSStorageConfig,
     S3StorageConfig,
     OCIStorageConfig,
     # configs of vdb and vdb providers

View File

@@ -1,6 +1,6 @@
 from typing import Optional

-from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
+from pydantic import Field, NonNegativeInt, PositiveInt
 from pydantic_settings import BaseSettings

@@ -38,33 +38,3 @@ class RedisConfig(BaseSettings):
         description="whether to use SSL for Redis connection",
         default=False,
     )
-
-    REDIS_USE_SENTINEL: Optional[bool] = Field(
-        description="Whether to use Redis Sentinel mode",
-        default=False,
-    )
-
-    REDIS_SENTINELS: Optional[str] = Field(
-        description="Redis Sentinel nodes",
-        default=None,
-    )
-
-    REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
-        description="Redis Sentinel service name",
-        default=None,
-    )
-
-    REDIS_SENTINEL_USERNAME: Optional[str] = Field(
-        description="Redis Sentinel username",
-        default=None,
-    )
-
-    REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
-        description="Redis Sentinel password",
-        default=None,
-    )
-
-    REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
-        description="Redis Sentinel socket timeout",
-        default=0.1,
-    )

View File

@@ -38,8 +38,3 @@ class AliyunOSSStorageConfig(BaseSettings):
         description="Aliyun OSS authentication version",
         default=None,
     )
-
-    ALIYUN_OSS_PATH: Optional[str] = Field(
-        description="Aliyun OSS path",
-        default=None,
-    )

View File

@@ -1,29 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel, Field
-
-
-class HuaweiCloudOBSStorageConfig(BaseModel):
-    """
-    Huawei Cloud OBS storage configs
-    """
-
-    HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
-        description="Huawei Cloud OBS bucket name",
-        default=None,
-    )
-
-    HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
-        description="Huawei Cloud OBS Access key",
-        default=None,
-    )
-
-    HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
-        description="Huawei Cloud OBS Secret key",
-        default=None,
-    )
-
-    HUAWEI_OBS_SERVER: Optional[str] = Field(
-        description="Huawei Cloud OBS server URL",
-        default=None,
-    )

View File

@@ -1,34 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel, Field
-
-
-class VolcengineTOSStorageConfig(BaseModel):
-    """
-    Volcengine tos storage configs
-    """
-
-    VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
-        description="Volcengine TOS Bucket Name",
-        default=None,
-    )
-
-    VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
-        description="Volcengine TOS Access Key",
-        default=None,
-    )
-
-    VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
-        description="Volcengine TOS Secret Key",
-        default=None,
-    )
-
-    VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
-        description="Volcengine TOS Endpoint URL",
-        default=None,
-    )
-
-    VOLCENGINE_TOS_REGION: Optional[str] = Field(
-        description="Volcengine TOS Region",
-        default=None,
-    )

View File

@@ -1,6 +1,6 @@
 from typing import Optional

-from pydantic import Field
+from pydantic import Field, PositiveInt
 from pydantic_settings import BaseSettings

@@ -9,14 +9,14 @@ class MilvusConfig(BaseSettings):
     Milvus configs
     """

-    MILVUS_URI: Optional[str] = Field(
-        description="Milvus uri",
-        default="http://127.0.0.1:19530",
-    )
+    MILVUS_HOST: Optional[str] = Field(
+        description="Milvus host",
+        default=None,
+    )

-    MILVUS_TOKEN: Optional[str] = Field(
-        description="Milvus token",
-        default=None,
-    )
+    MILVUS_PORT: PositiveInt = Field(
+        description="Milvus RestFul API port",
+        default=9091,
+    )

     MILVUS_USER: Optional[str] = Field(
@@ -29,6 +29,11 @@ class MilvusConfig(BaseSettings):
         default=None,
     )

+    MILVUS_SECURE: bool = Field(
+        description="whether to use SSL connection for Milvus",
+        default=False,
+    )
+
     MILVUS_DATABASE: str = Field(
         description="Milvus database, default to `default`",
         default="default",
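The Milvus hunk above replaces a single `MILVUS_URI` field with separate `MILVUS_HOST`/`MILVUS_PORT`/`MILVUS_SECURE` fields. Both shapes carry the same connection information; a small hypothetical helper (not Dify code) showing how the host/port/secure trio composes into the URI form:

```python
# Hypothetical helper: composes the older host/port/secure settings
# into the single-URI form used by the newer config's MILVUS_URI.
def milvus_uri(host: str, port: int, secure: bool = False) -> str:
    scheme = "https" if secure else "http"  # MILVUS_SECURE toggles TLS
    return f"{scheme}://{host}:{port}"

print(milvus_uri("127.0.0.1", 19530))        # http://127.0.0.1:19530
print(milvus_uri("127.0.0.1", 19530, True))  # https://127.0.0.1:19530
```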

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.8.3",
+        default="0.7.2",
     )

     COMMIT_SHA: str = Field(

File diff suppressed because one or more lines are too long

View File

@@ -60,15 +60,23 @@ class InsertExploreAppListApi(Resource):
         site = app.site
         if not site:
-            desc = args["desc"] or ""
-            copy_right = args["copyright"] or ""
-            privacy_policy = args["privacy_policy"] or ""
-            custom_disclaimer = args["custom_disclaimer"] or ""
+            desc = args["desc"] if args["desc"] else ""
+            copy_right = args["copyright"] if args["copyright"] else ""
+            privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
+            custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
         else:
-            desc = site.description or args["desc"] or ""
-            copy_right = site.copyright or args["copyright"] or ""
-            privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
-            custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
+            desc = site.description if site.description else args["desc"] if args["desc"] else ""
+            copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
+            privacy_policy = (
+                site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
+            )
+            custom_disclaimer = (
+                site.custom_disclaimer
+                if site.custom_disclaimer
+                else args["custom_disclaimer"]
+                if args["custom_disclaimer"]
+                else ""
+            )

         recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
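The hunk above trades `or` chains for nested conditional expressions. Both pick the first truthy candidate and fall back to `""`; the `or` form reads left to right without nesting because `a or b` evaluates to `b` whenever `a` is falsy (`None` and `""` alike). A generic sketch of that fallback pattern (the helper name is illustrative):

```python
# `site.description or args["desc"] or ""` returns the first truthy
# candidate, skipping None and "" equally. This helper generalizes it.
def first_truthy(*candidates):
    for value in candidates:
        if value:
            return value
    return ""

print(first_truthy(None, "", "from args"))     # from args
print(first_truthy("from site", "from args"))  # from site
print(first_truthy(None, None))                # (empty string)
```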

View File

@@ -57,7 +57,7 @@ class BaseApiKeyListResource(Resource):
     def post(self, resource_id):
         resource_id = str(resource_id)
         _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
-        if not current_user.is_editor:
+        if not current_user.is_admin_or_owner:
             raise Forbidden()

         current_key_count = (

View File

@@ -174,7 +174,6 @@ class AppApi(Resource):
         parser.add_argument("icon", type=str, location="json")
         parser.add_argument("icon_background", type=str, location="json")
         parser.add_argument("max_active_requests", type=int, location="json")
-        parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
         args = parser.parse_args()

         app_service = AppService()

View File

@@ -94,15 +94,19 @@ class ChatMessageTextApi(Resource):
         message_id = args.get("message_id", None)
         text = args.get("text", None)
         if (
-            app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
+            app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
             and app_model.workflow
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") or text_to_speech.get("voice")
+            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
         else:
             try:
-                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
+                voice = (
+                    args.get("voice")
+                    if args.get("voice")
+                    else app_model.app_model_config.text_to_speech_dict.get("voice")
+                )
             except Exception:
                 voice = None

         response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)

View File

@@ -20,7 +20,7 @@ from fields.conversation_fields import (
     conversation_pagination_fields,
     conversation_with_summary_pagination_fields,
 )
-from libs.helper import DatetimeString
+from libs.helper import datetime_string
 from libs.login import login_required
 from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
@@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
             raise Forbidden()
         parser = reqparse.RequestParser()
         parser.add_argument("keyword", type=str, location="args")
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         parser.add_argument(
             "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
         )
@@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
             raise Forbidden()
         parser = reqparse.RequestParser()
         parser.add_argument("keyword", type=str, location="args")
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         parser.add_argument(
             "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
         )
@@ -201,11 +201,7 @@ class ChatConversationApi(Resource):
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-            match args["sort_by"]:
-                case "updated_at" | "-updated_at":
-                    query = query.where(Conversation.updated_at >= start_datetime_utc)
-                case "created_at" | "-created_at" | _:
-                    query = query.where(Conversation.created_at >= start_datetime_utc)
+            query = query.where(Conversation.created_at >= start_datetime_utc)
         if args["end"]:
             end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
@@ -214,11 +210,7 @@ class ChatConversationApi(Resource):
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-            match args["sort_by"]:
-                case "updated_at" | "-updated_at":
-                    query = query.where(Conversation.updated_at <= end_datetime_utc)
-                case "created_at" | "-created_at" | _:
-                    query = query.where(Conversation.created_at <= end_datetime_utc)
+            query = query.where(Conversation.created_at < end_datetime_utc)
         if args["annotation_status"] == "annotated":
             query = query.options(joinedload(Conversation.message_annotations)).join(
@@ -34,7 +34,6 @@ def parse_app_site_args():
     )
     parser.add_argument("prompt_public", type=bool, required=False, location="json")
     parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
-    parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
     return parser.parse_args()
@@ -69,7 +68,6 @@ class AppSite(Resource):
             "customize_token_strategy",
             "prompt_public",
             "show_workflow_steps",
-            "use_icon_as_answer_icon",
         ]:
             value = args.get(attr_name)
             if value is not None:
@@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from extensions.ext_database import db
-from libs.helper import DatetimeString
+from libs.helper import datetime_string
 from libs.login import login_required
 from models.model import AppMode
@@ -25,17 +25,14 @@ class DailyMessageStatistic(Resource):
         account = current_user
         parser = reqparse.RequestParser()
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-    COUNT(*) AS message_count
-FROM
-    messages
-WHERE
-    app_id = :app_id"""
+        sql_query = """
+        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(*) AS message_count
+            FROM messages where app_id = :app_id
+        """
         arg_dict = {"tz": account.timezone, "app_id": app_model.id}
         timezone = pytz.timezone(account.timezone)
@@ -48,7 +45,7 @@ WHERE
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at >= :start"
+            sql_query += " and created_at >= :start"
             arg_dict["start"] = start_datetime_utc
         if args["end"]:
@@ -58,10 +55,10 @@ WHERE
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at < :end"
+            sql_query += " and created_at < :end"
             arg_dict["end"] = end_datetime_utc
-        sql_query += " GROUP BY date ORDER BY date"
+        sql_query += " GROUP BY date order by date"
         response_data = []
@@ -82,17 +79,14 @@ class DailyConversationStatistic(Resource):
         account = current_user
         parser = reqparse.RequestParser()
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-    COUNT(DISTINCT messages.conversation_id) AS conversation_count
-FROM
-    messages
-WHERE
-    app_id = :app_id"""
+        sql_query = """
+        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
+            FROM messages where app_id = :app_id
+        """
         arg_dict = {"tz": account.timezone, "app_id": app_model.id}
         timezone = pytz.timezone(account.timezone)
@@ -105,7 +99,7 @@ WHERE
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at >= :start"
+            sql_query += " and created_at >= :start"
             arg_dict["start"] = start_datetime_utc
         if args["end"]:
@@ -115,10 +109,10 @@ WHERE
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at < :end"
+            sql_query += " and created_at < :end"
             arg_dict["end"] = end_datetime_utc
-        sql_query += " GROUP BY date ORDER BY date"
+        sql_query += " GROUP BY date order by date"
         response_data = []
@@ -139,17 +133,14 @@ class DailyTerminalsStatistic(Resource):
         account = current_user
         parser = reqparse.RequestParser()
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-    COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
-FROM
-    messages
-WHERE
-    app_id = :app_id"""
+        sql_query = """
+        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
+            FROM messages where app_id = :app_id
+        """
         arg_dict = {"tz": account.timezone, "app_id": app_model.id}
         timezone = pytz.timezone(account.timezone)
@@ -162,7 +153,7 @@ WHERE
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at >= :start"
+            sql_query += " and created_at >= :start"
             arg_dict["start"] = start_datetime_utc
         if args["end"]:
@@ -172,10 +163,10 @@ WHERE
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at < :end"
+            sql_query += " and created_at < :end"
             arg_dict["end"] = end_datetime_utc
-        sql_query += " GROUP BY date ORDER BY date"
+        sql_query += " GROUP BY date order by date"
         response_data = []
@@ -196,18 +187,16 @@ class DailyTokenCostStatistic(Resource):
         account = current_user
         parser = reqparse.RequestParser()
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-    (SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
-    SUM(total_price) AS total_price
-FROM
-    messages
-WHERE
-    app_id = :app_id"""
+        sql_query = """
+        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+            (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
+            sum(total_price) as total_price
+            FROM messages where app_id = :app_id
+        """
         arg_dict = {"tz": account.timezone, "app_id": app_model.id}
         timezone = pytz.timezone(account.timezone)
@@ -220,7 +209,7 @@ WHERE
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at >= :start"
+            sql_query += " and created_at >= :start"
             arg_dict["start"] = start_datetime_utc
         if args["end"]:
@@ -230,10 +219,10 @@ WHERE
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at < :end"
+            sql_query += " and created_at < :end"
             arg_dict["end"] = end_datetime_utc
-        sql_query += " GROUP BY date ORDER BY date"
+        sql_query += " GROUP BY date order by date"
         response_data = []
@@ -256,26 +245,16 @@ class AverageSessionInteractionStatistic(Resource):
         account = current_user
         parser = reqparse.RequestParser()
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-    AVG(subquery.message_count) AS interactions
-FROM
-    (
-        SELECT
-            m.conversation_id,
-            COUNT(m.id) AS message_count
-        FROM
-            conversations c
-        JOIN
-            messages m
-            ON c.id = m.conversation_id
-        WHERE
-            c.override_model_configs IS NULL
-            AND c.app_id = :app_id"""
+        sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+            AVG(subquery.message_count) AS interactions
+        FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
+            FROM conversations c
+            JOIN messages m ON c.id = m.conversation_id
+            WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
         arg_dict = {"tz": account.timezone, "app_id": app_model.id}
         timezone = pytz.timezone(account.timezone)
@@ -288,7 +267,7 @@ FROM
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND c.created_at >= :start"
+            sql_query += " and c.created_at >= :start"
             arg_dict["start"] = start_datetime_utc
         if args["end"]:
@@ -298,19 +277,14 @@ FROM
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND c.created_at < :end"
+            sql_query += " and c.created_at < :end"
             arg_dict["end"] = end_datetime_utc
-        sql_query += """
-GROUP BY m.conversation_id
-) subquery
-LEFT JOIN
-    conversations c
-    ON c.id = subquery.conversation_id
-GROUP BY
-    date
-ORDER BY
-    date"""
+        sql_query += """
+            GROUP BY m.conversation_id) subquery
+        LEFT JOIN conversations c on c.id=subquery.conversation_id
+        GROUP BY date
+        ORDER BY date"""
         response_data = []
@@ -333,21 +307,17 @@ class UserSatisfactionRateStatistic(Resource):
         account = current_user
         parser = reqparse.RequestParser()
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-    COUNT(m.id) AS message_count,
-    COUNT(mf.id) AS feedback_count
-FROM
-    messages m
-LEFT JOIN
-    message_feedbacks mf
-    ON mf.message_id=m.id AND mf.rating='like'
-WHERE
-    m.app_id = :app_id"""
+        sql_query = """
+        SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+            COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
+            FROM messages m
+            LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
+            WHERE m.app_id = :app_id
+        """
         arg_dict = {"tz": account.timezone, "app_id": app_model.id}
         timezone = pytz.timezone(account.timezone)
@@ -360,7 +330,7 @@ WHERE
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND m.created_at >= :start"
+            sql_query += " and m.created_at >= :start"
             arg_dict["start"] = start_datetime_utc
         if args["end"]:
@@ -370,10 +340,10 @@ WHERE
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND m.created_at < :end"
+            sql_query += " and m.created_at < :end"
             arg_dict["end"] = end_datetime_utc
-        sql_query += " GROUP BY date ORDER BY date"
+        sql_query += " GROUP BY date order by date"
         response_data = []
@@ -399,17 +369,16 @@ class AverageResponseTimeStatistic(Resource):
         account = current_user
         parser = reqparse.RequestParser()
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-    AVG(provider_response_latency) AS latency
-FROM
-    messages
-WHERE
-    app_id = :app_id"""
+        sql_query = """
+        SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+            AVG(provider_response_latency) as latency
+            FROM messages
+            WHERE app_id = :app_id
+        """
         arg_dict = {"tz": account.timezone, "app_id": app_model.id}
         timezone = pytz.timezone(account.timezone)
@@ -422,7 +391,7 @@ WHERE
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at >= :start"
+            sql_query += " and created_at >= :start"
             arg_dict["start"] = start_datetime_utc
         if args["end"]:
@@ -432,10 +401,10 @@ WHERE
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at < :end"
+            sql_query += " and created_at < :end"
             arg_dict["end"] = end_datetime_utc
-        sql_query += " GROUP BY date ORDER BY date"
+        sql_query += " GROUP BY date order by date"
         response_data = []
@@ -456,20 +425,17 @@ class TokensPerSecondStatistic(Resource):
         account = current_user
         parser = reqparse.RequestParser()
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
-        sql_query = """SELECT
-    DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-    CASE
+        sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+    CASE
         WHEN SUM(provider_response_latency) = 0 THEN 0
         ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
     END as tokens_per_second
-FROM
-    messages
-WHERE
-    app_id = :app_id"""
+    FROM messages
+    WHERE app_id = :app_id"""
         arg_dict = {"tz": account.timezone, "app_id": app_model.id}
         timezone = pytz.timezone(account.timezone)
@@ -482,7 +448,7 @@ WHERE
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at >= :start"
+            sql_query += " and created_at >= :start"
             arg_dict["start"] = start_datetime_utc
         if args["end"]:
@@ -492,10 +458,10 @@ WHERE
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-            sql_query += " AND created_at < :end"
+            sql_query += " and created_at < :end"
             arg_dict["end"] = end_datetime_utc
-        sql_query += " GROUP BY date ORDER BY date"
+        sql_query += " GROUP BY date order by date"
         response_data = []
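Every hunk in this file swaps the argument type `datetime_string` for `DatetimeString`, which suggests the helper in `libs/helper.py` was rewritten as a class-based reqparse type. The real helper may differ; a hypothetical sketch of how such a type plugs into `parser.add_argument(..., type=...)` — an instance is called with the raw query-string value and raises `ValueError` on bad input:

```python
from datetime import datetime


class DatetimeString:
    # Hypothetical sketch: reqparse calls the `type` object with the raw
    # string; returning the value accepts it, raising ValueError rejects it.
    def __init__(self, fmt: str):
        self.fmt = fmt

    def __call__(self, value: str) -> str:
        try:
            datetime.strptime(value, self.fmt)
        except ValueError:
            raise ValueError(f"{value!r} is not a valid datetime ({self.fmt})")
        return value


parse = DatetimeString("%Y-%m-%d %H:%M")
print(parse("2024-08-28 17:09"))  # valid input is returned unchanged
```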
@@ -465,6 +465,6 @@ api.add_resource(
 api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
 api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
 api.add_resource(
-    DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
+    DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
 )
 api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
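The right-hand (older) route string relies on Python's implicit concatenation of adjacent string literals, so both spellings register the same URL rule; the merged literal is simply easier to read:

```python
# Adjacent string literals are joined at compile time, so the two
# route spellings in the hunk above denote the same string.
split_literal = "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
merged_literal = "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
print(split_literal == merged_literal)  # True
```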
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import DatetimeString from libs.helper import datetime_string
from libs.login import login_required from libs.login import login_required
from models.model import AppMode from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom from models.workflow import WorkflowRunTriggeredFrom
@ -26,18 +26,16 @@ class WorkflowDailyRunsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
COUNT(id) AS runs FROM workflow_runs
FROM WHERE app_id = :app_id
workflow_runs AND triggered_from = :triggered_from
WHERE """
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
@ -54,7 +52,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start" sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -64,10 +62,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += " GROUP BY date order by date"
response_data = [] response_data = []
@ -88,18 +86,16 @@ class WorkflowDailyTerminalsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count FROM workflow_runs
FROM WHERE app_id = :app_id
workflow_runs AND triggered_from = :triggered_from
WHERE """
app_id = :app_id
AND triggered_from = :triggered_from"""
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
@ -116,7 +112,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start" sql_query += " and created_at >= :start"
arg_dict["start"] = start_datetime_utc arg_dict["start"] = start_datetime_utc
if args["end"]: if args["end"]:
@ -126,10 +122,10 @@ WHERE
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += " and created_at < :end"
arg_dict["end"] = end_datetime_utc arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += " GROUP BY date order by date"
response_data = [] response_data = []
@ -150,18 +146,18 @@ class WorkflowDailyTokenCostStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT
SUM(workflow_runs.total_tokens) AS token_count date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
FROM SUM(workflow_runs.total_tokens) as token_count
workflow_runs FROM workflow_runs
WHERE WHERE app_id = :app_id
app_id = :app_id AND triggered_from = :triggered_from
AND triggered_from = :triggered_from""" """
arg_dict = { arg_dict = {
"tz": account.timezone, "tz": account.timezone,
"app_id": app_model.id, "app_id": app_model.id,
@ -178,7 +174,7 @@ WHERE
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
         start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-        sql_query += " AND created_at >= :start"
+        sql_query += " and created_at >= :start"
         arg_dict["start"] = start_datetime_utc
     if args["end"]:
@@ -188,10 +184,10 @@ WHERE
         end_datetime_timezone = timezone.localize(end_datetime)
         end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-        sql_query += " AND created_at < :end"
+        sql_query += " and created_at < :end"
         arg_dict["end"] = end_datetime_utc
-    sql_query += " GROUP BY date ORDER BY date"
+    sql_query += " GROUP BY date order by date"
     response_data = []
@@ -217,31 +213,27 @@ class WorkflowAverageAppInteractionStatistic(Resource):
         account = current_user
         parser = reqparse.RequestParser()
-        parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
-        parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
+        parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
         args = parser.parse_args()
-        sql_query = """SELECT
-    AVG(sub.interactions) AS interactions,
-    sub.date
-FROM
-    (
-        SELECT
-            DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
-            c.created_by,
-            COUNT(c.id) AS interactions
-        FROM
-            workflow_runs c
-        WHERE
-            c.app_id = :app_id
-            AND c.triggered_from = :triggered_from
-            {{start}}
-            {{end}}
-        GROUP BY
-            date, c.created_by
-    ) sub
-GROUP BY
-    sub.date"""
+        sql_query = """
+            SELECT
+            AVG(sub.interactions) as interactions,
+            sub.date
+            FROM
+            (SELECT
+            date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+            c.created_by,
+            COUNT(c.id) AS interactions
+            FROM workflow_runs c
+            WHERE c.app_id = :app_id
+            AND c.triggered_from = :triggered_from
+            {{start}}
+            {{end}}
+            GROUP BY date, c.created_by) sub
+            GROUP BY sub.date
+        """
         arg_dict = {
             "tz": account.timezone,
             "app_id": app_model.id,
@@ -270,7 +262,7 @@ GROUP BY
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
-            sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
+            sql_query = sql_query.replace("{{end}}", " and c.created_at < :end")
             arg_dict["end"] = end_datetime_utc
         else:
             sql_query = sql_query.replace("{{end}}", "")
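The statistics hunks above localize a naive user-supplied datetime into the account's timezone, convert it to UTC, and bind it as a SQL parameter. A minimal sketch of that conversion using the stdlib `zoneinfo` (the controllers use pytz's `localize`/`astimezone`; the helper name here is illustrative):

```python
from datetime import datetime
from zoneinfo import ZoneInfo  # stdlib stand-in for the pytz objects used in the controller

def to_utc(naive_local: datetime, tz_name: str) -> datetime:
    """Interpret a naive datetime in the user's timezone, then convert to UTC."""
    localized = naive_local.replace(tzinfo=ZoneInfo(tz_name))
    return localized.astimezone(ZoneInfo("UTC"))

# 17:09 in Asia/Shanghai (UTC+8) corresponds to 09:09 UTC
start_utc = to_utc(datetime(2024, 8, 28, 17, 9), "Asia/Shanghai")
```

Converting before binding means the `created_at >= :start` comparison runs against UTC timestamps stored in the database, independent of the viewer's timezone.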

View File

@@ -8,7 +8,7 @@ from constants.languages import supported_language
 from controllers.console import api
 from controllers.console.error import AlreadyActivateError
 from extensions.ext_database import db
-from libs.helper import StrLen, email, timezone
+from libs.helper import email, str_len, timezone
 from libs.password import hash_password, valid_password
 from models.account import AccountStatus
 from services.account_service import RegisterService
@@ -37,7 +37,7 @@ class ActivateApi(Resource):
         parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
         parser.add_argument("email", type=email, required=False, nullable=True, location="json")
         parser.add_argument("token", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
         parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
         parser.add_argument(
             "interface_language", type=supported_language, required=True, nullable=False, location="json"

View File

@@ -71,7 +71,7 @@ class OAuthCallback(Resource):
         account = _generate_account(provider, user_info)
         # Check account status
-        if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
+        if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
             return {"error": "Account is banned or closed."}, 403
         if account.status == AccountStatus.PENDING.value:
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
     if not account:
         # Create account
-        account_name = user_info.name or "Dify"
+        account_name = user_info.name if user_info.name else "Dify"
         account = RegisterService.register(
             email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
         )
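Both sides of the `account_name` hunk above behave identically: `x or default` evaluates to `default` whenever `x` is falsy (`None` or an empty string), which is why the two spellings are interchangeable here. A small sketch:

```python
def pick_account_name(name):
    # `name or "Dify"` returns the first truthy operand,
    # equivalent to `name if name else "Dify"`
    return name or "Dify"

assert pick_account_name("Yi") == "Yi"
assert pick_account_name("") == "Dify"
assert pick_account_name(None) == "Dify"
```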

View File

@@ -18,7 +18,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
-from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.rag.retrieval.retrival_methods import RetrievalMethod
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -122,7 +122,6 @@ class DatasetListApi(Resource):
                 name=args["name"],
                 indexing_technique=args["indexing_technique"],
                 account=current_user,
-                permission=DatasetPermissionEnum.ONLY_ME,
             )
         except services.errors.dataset.DatasetNameDuplicateError:
             raise DatasetNameDuplicateError()
@@ -399,7 +398,7 @@ class DatasetIndexingEstimateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -550,7 +549,12 @@ class DatasetApiBaseUrlApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
+        return {
+            "api_base_url": (
+                dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
+            )
+            + "/v1"
+        }
 class DatasetRetrievalSettingApi(Resource):

View File

@@ -302,8 +302,6 @@ class DatasetInitApi(Resource):
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
-        parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
-        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         args = parser.parse_args()
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -311,8 +309,6 @@ class DatasetInitApi(Resource):
             raise Forbidden()
         if args["indexing_technique"] == "high_quality":
-            if args["embedding_model"] is None or args["embedding_model_provider"] is None:
-                raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
             try:
                 model_manager = ModelManager()
                 model_manager.get_default_model_instance(
@@ -354,7 +350,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
         document_id = str(document_id)
         document = self.get_document(dataset_id, document_id)
-        if document.indexing_status in {"completed", "error"}:
+        if document.indexing_status in ["completed", "error"]:
             raise DocumentAlreadyFinishedError()
         data_process_rule = document.dataset_process_rule
@@ -421,7 +417,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         info_list = []
         extract_settings = []
         for document in documents:
-            if document.indexing_status in {"completed", "error"}:
+            if document.indexing_status in ["completed", "error"]:
                 raise DocumentAlreadyFinishedError()
             data_source_info = document.data_source_info_dict
             # format document files info
@@ -665,7 +661,7 @@ class DocumentProcessingApi(DocumentResource):
             db.session.commit()
         elif action == "resume":
-            if document.indexing_status not in {"paused", "error"}:
+            if document.indexing_status not in ["paused", "error"]:
                 raise InvalidActionError("Document not in paused or error state.")
             document.paused_by = None
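The recurring `in {...}` vs `in [...]` hunks only change the literal type of the allowed-status collection; both accept and reject the same strings. Set literals give average O(1) membership tests, which is why one side of the diff prefers them. A sketch of the status guard (using a plain `ValueError` in place of `DocumentAlreadyFinishedError`):

```python
FINISHED_STATUSES = {"completed", "error"}  # set literal: O(1) average lookup

def assert_not_finished(indexing_status: str) -> None:
    # Mirrors the guard in DocumentIndexingEstimateApi
    if indexing_status in FINISHED_STATUSES:
        raise ValueError("Document already finished.")

assert_not_finished("indexing")  # still in progress, so no exception
```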

View File

@@ -39,7 +39,7 @@ class FileApi(Resource):
     @login_required
     @account_initialization_required
     @marshal_with(file_fields)
-    @cloud_edition_billing_resource_check("documents")
+    @cloud_edition_billing_resource_check(resource="documents")
     def post(self):
         # get file from request
         file = request.files["file"]

View File

@@ -18,7 +18,9 @@ class NotSetupError(BaseHTTPException):
 class NotInitValidateError(BaseHTTPException):
     error_code = "not_init_validated"
-    description = "Init validation has not been completed yet. Please proceed with the init validation process first."
+    description = (
+        "Init validation has not been completed yet. " "Please proceed with the init validation process first."
+    )
     code = 401
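The `description` hunk relies on Python's implicit concatenation of adjacent string literals: the parenthesized pair compiles to a single string, so both sides of the diff produce the identical message. For example:

```python
# Adjacent string literals are joined at compile time into one string
description = (
    "Init validation has not been completed yet. "
    "Please proceed with the init validation process first."
)

assert "yet. Please" in description
```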

View File

@@ -81,15 +81,19 @@ class ChatTextApi(InstalledAppResource):
         message_id = args.get("message_id", None)
         text = args.get("text", None)
         if (
-            app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
+            app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
             and app_model.workflow
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") or text_to_speech.get("voice")
+            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
         else:
             try:
-                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
+                voice = (
+                    args.get("voice")
+                    if args.get("voice")
+                    else app_model.app_model_config.text_to_speech_dict.get("voice")
+                )
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)

View File

@@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource):
     def post(self, installed_app):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         parser = reqparse.RequestParser()
@@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource):
     def post(self, installed_app, task_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)

View File

@@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource):
     def get(self, installed_app):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         parser = reqparse.RequestParser()
@@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource):
     def delete(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         conversation_id = str(c_id)
@@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource):
     def post(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         conversation_id = str(c_id)
@@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource):
     def patch(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         conversation_id = str(c_id)
@@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource):
     def patch(self, installed_app, c_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         conversation_id = str(c_id)

View File

@@ -31,11 +31,10 @@ class InstalledAppsListApi(Resource):
                 "app_owner_tenant_id": installed_app.app_owner_tenant_id,
                 "is_pinned": installed_app.is_pinned,
                 "last_used_at": installed_app.last_used_at,
-                "editable": current_user.role in {"owner", "admin"},
+                "editable": current_user.role in ["owner", "admin"],
                 "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
             }
             for installed_app in installed_apps
-            if installed_app.app is not None
         ]
         installed_apps.sort(
             key=lambda app: (

View File

@@ -40,7 +40,7 @@ class MessageListApi(InstalledAppResource):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         parser = reqparse.RequestParser()
@@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
     def get(self, installed_app, message_id):
         app_model = installed_app.app
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         message_id = str(message_id)

View File

@@ -43,7 +43,7 @@ class AppParameterApi(InstalledAppResource):
         """Retrieve app parameters."""
         app_model = installed_app.app
-        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
+        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
             workflow = app_model.workflow
             if workflow is None:
                 raise AppUnavailableError()

View File

@@ -4,7 +4,7 @@ from flask import session
 from flask_restful import Resource, reqparse
 from configs import dify_config
-from libs.helper import StrLen
+from libs.helper import str_len
 from models.model import DifySetup
 from services.account_service import TenantService
@@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
             raise AlreadySetupError()
         parser = reqparse.RequestParser()
-        parser.add_argument("password", type=StrLen(30), required=True, location="json")
+        parser.add_argument("password", type=str_len(30), required=True, location="json")
         input_password = parser.parse_args()["password"]
         if input_password != os.environ.get("INIT_PASSWORD"):

View File

@@ -4,7 +4,7 @@ from flask import request
 from flask_restful import Resource, reqparse
 from configs import dify_config
-from libs.helper import StrLen, email, get_remote_ip
+from libs.helper import email, get_remote_ip, str_len
 from libs.password import valid_password
 from models.model import DifySetup
 from services.account_service import RegisterService, TenantService
@@ -40,7 +40,7 @@ class SetupApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument("email", type=email, required=True, location="json")
-        parser.add_argument("name", type=StrLen(30), required=True, location="json")
+        parser.add_argument("name", type=str_len(30), required=True, location="json")
         parser.add_argument("password", type=valid_password, required=True, location="json")
         args = parser.parse_args()

View File

@@ -13,7 +13,7 @@ from services.tag_service import TagService
 def _validate_name(name):
-    if not name or len(name) < 1 or len(name) > 50:
+    if not name or len(name) < 1 or len(name) > 40:
         raise ValueError("Name must be between 1 to 50 characters.")
     return name
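Validators like `_validate_name` above, and the `StrLen`/`str_len` type passed to `parser.add_argument` elsewhere in this diff, follow the same flask-restful reqparse convention: a callable that returns the parsed value or raises `ValueError`, which reqparse turns into a 400 response. A hypothetical sketch of such a type (the real helper in `libs.helper` may differ in detail):

```python
class StrLen:
    """Reqparse-style type check: accept non-empty strings up to max_length characters."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, value):
        # Return the value unchanged when valid; raise ValueError otherwise
        if not isinstance(value, str) or not (1 <= len(value) <= self.max_length):
            raise ValueError(f"Must be a string between 1 and {self.max_length} characters.")
        return value

validate_name = StrLen(30)
assert validate_name("Yi") == "Yi"
```

Instances are passed directly as `type=StrLen(30)`, so the parser calls them once per incoming field.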

View File

@@ -218,7 +218,7 @@ api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-provider
 api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate")
 api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>")
 api.add_resource(
-    ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/<string:icon_type>/<string:lang>"
+    ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>"
 )
 api.add_resource(

View File

@@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):
         return ApiToolManageService.test_api_tool_preview(
             current_user.current_tenant_id,
-            args["provider_name"] or "",
+            args["provider_name"] if args["provider_name"] else "",
             args["tool_name"],
             args["credentials"],
             args["parameters"],

View File

@@ -194,7 +194,7 @@ class WebappLogoWorkspaceApi(Resource):
             raise TooManyFilesError()
         extension = file.filename.split(".")[-1]
-        if extension.lower() not in {"svg", "png"}:
+        if extension.lower() not in ["svg", "png"]:
             raise UnsupportedFileTypeError()
         try:

View File

@@ -46,7 +46,9 @@ def only_edition_self_hosted(view):
     return decorated
-def cloud_edition_billing_resource_check(resource: str):
+def cloud_edition_billing_resource_check(
+    resource: str, error_msg: str = "You have reached the limit of your subscription."
+):
     def interceptor(view):
         @wraps(view)
         def decorated(*args, **kwargs):
@@ -58,23 +60,22 @@ def cloud_edition_billing_resource_check(resource: str):
             documents_upload_quota = features.documents_upload_quota
             annotation_quota_limit = features.annotation_quota_limit
             if resource == "members" and 0 < members.limit <= members.size:
-                abort(403, "The number of members has reached the limit of your subscription.")
+                abort(403, error_msg)
             elif resource == "apps" and 0 < apps.limit <= apps.size:
-                abort(403, "The number of apps has reached the limit of your subscription.")
+                abort(403, error_msg)
             elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
-                abort(403, "The capacity of the vector space has reached the limit of your subscription.")
+                abort(403, error_msg)
             elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
-                # The api of file upload is used in the multiple places,
-                # so we need to check the source of the request from datasets
+                # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
                 source = request.args.get("source")
                 if source == "datasets":
-                    abort(403, "The number of documents has reached the limit of your subscription.")
+                    abort(403, error_msg)
                 else:
                     return view(*args, **kwargs)
             elif resource == "workspace_custom" and not features.can_replace_logo:
-                abort(403, "The workspace custom feature has reached the limit of your subscription.")
+                abort(403, error_msg)
             elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
-                abort(403, "The annotation quota has reached the limit of your subscription.")
+                abort(403, error_msg)
             else:
                 return view(*args, **kwargs)
@@ -85,7 +86,10 @@ def cloud_edition_billing_resource_check(resource: str):
     return interceptor
-def cloud_edition_billing_knowledge_limit_check(resource: str):
+def cloud_edition_billing_knowledge_limit_check(
+    resource: str,
+    error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
+):
     def interceptor(view):
         @wraps(view)
         def decorated(*args, **kwargs):
@@ -93,10 +97,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
             if features.billing.enabled:
                 if resource == "add_segment":
                     if features.billing.subscription.plan == "sandbox":
-                        abort(
-                            403,
-                            "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
-                        )
+                        abort(403, error_msg)
                 else:
                     return view(*args, **kwargs)
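The billing hunks above move between per-resource messages and a single `error_msg` parameter on the decorator factory. The factory → interceptor → decorated nesting can be sketched as follows; the inline `limit`/`size` arguments and `PermissionError` are stand-ins for the real `FeatureService` lookup and flask-restful's `abort(403, ...)`:

```python
from functools import wraps

def resource_limit_check(limit: int, size: int,
                         error_msg: str = "You have reached the limit of your subscription."):
    # Outer factory captures configuration; interceptor wraps the actual view
    def interceptor(view):
        @wraps(view)  # preserve the view's name and docstring
        def decorated(*args, **kwargs):
            if 0 < limit <= size:  # a limit of 0 means "unlimited"
                raise PermissionError(error_msg)
            return view(*args, **kwargs)
        return decorated
    return interceptor

@resource_limit_check(limit=5, size=2)
def create_app():
    return "created"

assert create_app() == "created"  # under the limit, so the view runs
```

Because the factory takes arguments, it must be applied with parentheses (`@resource_limit_check(...)`), which is also why the diff's `@cloud_edition_billing_resource_check("documents")` call site passes the resource name explicitly.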

View File

@@ -42,7 +42,7 @@ class AppParameterApi(Resource):
     @marshal_with(parameters_fields)
     def get(self, app_model: App):
         """Retrieve app parameters."""
-        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
+        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
             workflow = app_model.workflow
             if workflow is None:
                 raise AppUnavailableError()

View File

@@ -79,15 +79,19 @@ class TextApi(Resource):
         message_id = args.get("message_id", None)
         text = args.get("text", None)
         if (
-            app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
+            app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
             and app_model.workflow
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") or text_to_speech.get("voice")
+            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
         else:
             try:
-                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
+                voice = (
+                    args.get("voice")
+                    if args.get("voice")
+                    else app_model.app_model_config.text_to_speech_dict.get("voice")
+                )
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(

View File

@@ -96,7 +96,7 @@ class ChatApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         parser = reqparse.RequestParser()
@@ -144,7 +144,7 @@ class ChatStopApi(Resource):
     @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
     def post(self, app_model: App, end_user: EndUser, task_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

View File

@@ -18,7 +18,7 @@ class ConversationApi(Resource):
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     def get(self, app_model: App, end_user: EndUser):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         parser = reqparse.RequestParser()
@@ -52,7 +52,7 @@ class ConversationDetailApi(Resource):
     @marshal_with(simple_conversation_fields)
     def delete(self, app_model: App, end_user: EndUser, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         conversation_id = str(c_id)
@@ -69,7 +69,7 @@ class ConversationRenameApi(Resource):
     @marshal_with(simple_conversation_fields)
     def post(self, app_model: App, end_user: EndUser, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         conversation_id = str(c_id)

View File

@@ -76,7 +76,7 @@ class MessageListApi(Resource):
     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, app_model: App, end_user: EndUser):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         parser = reqparse.RequestParser()
@@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource):
     def get(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
         try:

View File

@@ -1,7 +1,6 @@
 import logging
 
 from flask_restful import Resource, fields, marshal_with, reqparse
-from flask_restful.inputs import int_range
 from werkzeug.exceptions import InternalServerError
 
 from controllers.service_api import api
@@ -23,12 +22,10 @@ from core.errors.error import (
 )
 from core.model_runtime.errors.invoke import InvokeError
 from extensions.ext_database import db
-from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
 from libs import helper
 from models.model import App, AppMode, EndUser
 from models.workflow import WorkflowRun
 from services.app_generate_service import AppGenerateService
-from services.workflow_app_service import WorkflowAppService
 
 logger = logging.getLogger(__name__)
@@ -116,30 +113,6 @@ class WorkflowTaskStopApi(Resource):
         return {"result": "success"}
 
-class WorkflowAppLogApi(Resource):
-    @validate_app_token
-    @marshal_with(workflow_app_log_pagination_fields)
-    def get(self, app_model: App):
-        """
-        Get workflow app logs
-        """
-        parser = reqparse.RequestParser()
-        parser.add_argument("keyword", type=str, location="args")
-        parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
-        parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
-        parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
-        args = parser.parse_args()
-
-        # get paginate workflow app logs
-        workflow_app_service = WorkflowAppService()
-        workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
-            app_model=app_model, args=args
-        )
-
-        return workflow_app_log_pagination
-
 api.add_resource(WorkflowRunApi, "/workflows/run")
 api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>")
 api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop")
-api.add_resource(WorkflowAppLogApi, "/workflows/logs")

View File

@@ -36,10 +36,6 @@ class SegmentApi(DatasetApiResource):
         document = DocumentService.get_document(dataset.id, document_id)
         if not document:
             raise NotFound("Document not found.")
-        if document.indexing_status != "completed":
-            raise NotFound("Document is not completed.")
-        if not document.enabled:
-            raise NotFound("Document is disabled.")
         # check embedding model setting
         if dataset.indexing_technique == "high_quality":
             try:
@@ -67,7 +63,7 @@ class SegmentApi(DatasetApiResource):
             segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
             return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
         else:
-            return {"error": "Segments is required"}, 400
+            return {"error": "Segemtns is required"}, 400
 
     def get(self, tenant_id, dataset_id, document_id):
         """Create single segment."""

View File

@@ -83,7 +83,9 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
     return decorator(view)
 
-def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
+def cloud_edition_billing_resource_check(
+    resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription."
+):
     def interceptor(view):
         def decorated(*args, **kwargs):
             api_token = validate_and_get_api_token(api_token_type)
@@ -96,13 +98,13 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
                 documents_upload_quota = features.documents_upload_quota
 
                 if resource == "members" and 0 < members.limit <= members.size:
-                    raise Forbidden("The number of members has reached the limit of your subscription.")
+                    raise Forbidden(error_msg)
                 elif resource == "apps" and 0 < apps.limit <= apps.size:
-                    raise Forbidden("The number of apps has reached the limit of your subscription.")
+                    raise Forbidden(error_msg)
                 elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
-                    raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
+                    raise Forbidden(error_msg)
                 elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
-                    raise Forbidden("The number of documents has reached the limit of your subscription.")
+                    raise Forbidden(error_msg)
                 else:
                     return view(*args, **kwargs)
@@ -113,7 +115,11 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
     return interceptor
 
-def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
+def cloud_edition_billing_knowledge_limit_check(
+    resource: str,
+    api_token_type: str,
+    error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
+):
     def interceptor(view):
         @wraps(view)
         def decorated(*args, **kwargs):
@@ -122,9 +128,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
             if features.billing.enabled:
                 if resource == "add_segment":
                     if features.billing.subscription.plan == "sandbox":
-                        raise Forbidden(
-                            "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
-                        )
+                        raise Forbidden(error_msg)
                 else:
                     return view(*args, **kwargs)
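The hunks above parameterize the billing decorators so the error text is a decorator argument instead of being hard-coded per branch. A minimal, self-contained sketch of that decorator-factory pattern (all names here are hypothetical illustrations, not the real Dify API):

```python
from functools import wraps


class Forbidden(Exception):
    """Stand-in for werkzeug's Forbidden, to keep the sketch dependency-free."""


def billing_limit_check(limit: int, error_msg: str = "You have reached the limit of your subscription."):
    # decorator factory: `error_msg` is captured by the closure, so each
    # decorated view can carry its own message without duplicating the check
    def interceptor(view):
        @wraps(view)
        def decorated(*args, usage: int = 0, **kwargs):
            if 0 < limit <= usage:
                raise Forbidden(error_msg)
            return view(*args, **kwargs)
        return decorated
    return interceptor


@billing_limit_check(limit=3, error_msg="The number of apps has reached the limit.")
def create_app(name: str) -> str:
    return f"created {name}"
```

The design trade-off is the one the diff reverses: one generic message per decorator call versus a distinct hard-coded message per resource branch.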

View File

@@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource):
     @marshal_with(parameters_fields)
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
-        if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
+        if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]:
             workflow = app_model.workflow
             if workflow is None:
                 raise AppUnavailableError()

View File

@@ -78,15 +78,19 @@ class TextApi(WebApiResource):
         message_id = args.get("message_id", None)
         text = args.get("text", None)
         if (
-            app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
+            app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
             and app_model.workflow
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") or text_to_speech.get("voice")
+            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
         else:
             try:
-                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
+                voice = (
+                    args.get("voice")
+                    if args.get("voice")
+                    else app_model.app_model_config.text_to_speech_dict.get("voice")
+                )
             except Exception:
                 voice = None
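The voice-fallback hunk above swaps `x or default` for the spelled-out `x if x else default`. Both test the truthiness of the same value, so for the string-or-`None` inputs involved here they behave identically; `or` is simply the more idiomatic spelling. A hypothetical illustration:

```python
def pick_voice(requested, configured_default):
    # Equivalent to `requested or configured_default`: both fall back
    # whenever `requested` is falsy (None or ""), because the conditional
    # tests the truthiness of the very value it returns.
    return requested if requested else configured_default
```

The two forms only diverge when the condition differs from the returned value (e.g. `x if x is not None else default`), which would preserve falsy-but-valid values such as `""` or `0`.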

View File

@@ -87,7 +87,7 @@ class CompletionStopApi(WebApiResource):
 class ChatApi(WebApiResource):
     def post(self, app_model, end_user):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -136,7 +136,7 @@ class ChatApi(WebApiResource):
 class ChatStopApi(WebApiResource):
     def post(self, app_model, end_user, task_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
 
         AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)

View File

@@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource):
     @marshal_with(conversation_infinite_scroll_pagination_fields)
     def get(self, app_model, end_user):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -56,7 +56,7 @@ class ConversationListApi(WebApiResource):
 class ConversationApi(WebApiResource):
     def delete(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
 
         conversation_id = str(c_id)
@@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource):
     @marshal_with(simple_conversation_fields)
     def post(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
 
         conversation_id = str(c_id)
@@ -92,7 +92,7 @@ class ConversationRenameApi(WebApiResource):
 class ConversationPinApi(WebApiResource):
     def patch(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
 
         conversation_id = str(c_id)
@@ -108,7 +108,7 @@ class ConversationPinApi(WebApiResource):
 class ConversationUnPinApi(WebApiResource):
     def patch(self, app_model, end_user, c_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
 
         conversation_id = str(c_id)

View File

@@ -78,7 +78,7 @@ class MessageListApi(WebApiResource):
     @marshal_with(message_infinite_scroll_pagination_fields)
     def get(self, app_model, end_user):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
@@ -160,7 +160,7 @@ class MessageMoreLikeThisApi(WebApiResource):
 class MessageSuggestedQuestionApi(WebApiResource):
     def get(self, app_model, end_user, message_id):
         app_mode = AppMode.value_of(app_model.mode)
-        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+        if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]:
             raise NotCompletionAppError()
 
         message_id = str(message_id)

View File

@@ -39,7 +39,6 @@ class AppSiteApi(WebApiResource):
         "default_language": fields.String,
         "prompt_public": fields.Boolean,
         "show_workflow_steps": fields.Boolean,
-        "use_icon_as_answer_icon": fields.Boolean,
     }
 
     app_fields = {

View File

@@ -80,8 +80,7 @@ def _validate_web_sso_token(decoded, system_features, app_code):
         if not source or source != "sso":
             raise WebSSOAuthRequiredError()
 
-    # Check if SSO is not enforced for web, and if the token source is SSO,
-    # raise an error and redirect to normal passport login
+    # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
     if not system_features.sso_enforced_for_web or not app_web_sso_enabled:
         source = decoded.get("token_source")
         if source and source == "sso":

View File

@@ -1 +1 @@
-import core.moderation.base
+import core.moderation.base

View File

@ -1,7 +1,6 @@
import json import json
import logging import logging
import uuid import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional, Union, cast from typing import Optional, Union, cast
@ -46,25 +45,22 @@ from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner): class BaseAgentRunner(AppRunner):
def __init__( def __init__(self, tenant_id: str,
self, application_generate_entity: AgentChatAppGenerateEntity,
tenant_id: str, conversation: Conversation,
application_generate_entity: AgentChatAppGenerateEntity, app_config: AgentChatAppConfig,
conversation: Conversation, model_config: ModelConfigWithCredentialsEntity,
app_config: AgentChatAppConfig, config: AgentEntity,
model_config: ModelConfigWithCredentialsEntity, queue_manager: AppQueueManager,
config: AgentEntity, message: Message,
queue_manager: AppQueueManager, user_id: str,
message: Message, memory: Optional[TokenBufferMemory] = None,
user_id: str, prompt_messages: Optional[list[PromptMessage]] = None,
memory: Optional[TokenBufferMemory] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None,
prompt_messages: Optional[list[PromptMessage]] = None, db_variables: Optional[ToolConversationVariables] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None, model_instance: ModelInstance = None
db_variables: Optional[ToolConversationVariables] = None, ) -> None:
model_instance: ModelInstance = None,
) -> None:
""" """
Agent runner Agent runner
:param tenant_id: tenant id :param tenant_id: tenant id
@ -92,7 +88,9 @@ class BaseAgentRunner(AppRunner):
self.message = message self.message = message
self.user_id = user_id self.user_id = user_id
self.memory = memory self.memory = memory
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or []) self.history_prompt_messages = self.organize_agent_history(
prompt_messages=prompt_messages or []
)
self.variables_pool = variables_pool self.variables_pool = variables_pool
self.db_variables_pool = db_variables self.db_variables_pool = db_variables
self.model_instance = model_instance self.model_instance = model_instance
@ -113,16 +111,12 @@ class BaseAgentRunner(AppRunner):
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None, retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
return_resource=app_config.additional_features.show_retrieve_source, return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback, hit_callback=hit_callback
) )
# get how many agent thoughts have been created # get how many agent thoughts have been created
self.agent_thought_count = ( self.agent_thought_count = db.session.query(MessageAgentThought).filter(
db.session.query(MessageAgentThought) MessageAgentThought.message_id == self.message.id,
.filter( ).count()
MessageAgentThought.message_id == self.message.id,
)
.count()
)
db.session.close() db.session.close()
# check if model supports stream tool call # check if model supports stream tool call
@ -141,26 +135,25 @@ class BaseAgentRunner(AppRunner):
self.query = None self.query = None
self._current_thoughts: list[PromptMessage] = [] self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity( def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
self, app_generate_entity: AgentChatAppGenerateEntity -> AgentChatAppGenerateEntity:
) -> AgentChatAppGenerateEntity:
""" """
Repack app generate entity Repack app generate entity
""" """
if app_generate_entity.app_config.prompt_template.simple_prompt_template is None: if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
app_generate_entity.app_config.prompt_template.simple_prompt_template = "" app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
return app_generate_entity return app_generate_entity
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
""" """
convert tool to prompt message tool convert tool to prompt message tool
""" """
tool_entity = ToolManager.get_agent_tool_runtime( tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
app_id=self.app_config.app_id, app_id=self.app_config.app_id,
agent_tool=tool, agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from
) )
tool_entity.load_variables(self.variables_pool) tool_entity.load_variables(self.variables_pool)
@ -171,7 +164,7 @@ class BaseAgentRunner(AppRunner):
"type": "object", "type": "object",
"properties": {}, "properties": {},
"required": [], "required": [],
}, }
) )
parameters = tool_entity.get_all_runtime_parameters() parameters = tool_entity.get_all_runtime_parameters()
@ -184,19 +177,19 @@ class BaseAgentRunner(AppRunner):
if parameter.type == ToolParameter.ToolParameterType.SELECT: if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] enum = [option.value for option in parameter.options]
message_tool.parameters["properties"][parameter.name] = { message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type, "type": parameter_type,
"description": parameter.llm_description or "", "description": parameter.llm_description or '',
} }
if len(enum) > 0: if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum message_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required: if parameter.required:
message_tool.parameters["required"].append(parameter.name) message_tool.parameters['required'].append(parameter.name)
return message_tool, tool_entity return message_tool, tool_entity
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool: def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
""" """
convert dataset retriever tool to prompt message tool convert dataset retriever tool to prompt message tool
@ -208,24 +201,24 @@ class BaseAgentRunner(AppRunner):
"type": "object", "type": "object",
"properties": {}, "properties": {},
"required": [], "required": [],
}, }
) )
for parameter in tool.get_runtime_parameters(): for parameter in tool.get_runtime_parameters():
parameter_type = "string" parameter_type = 'string'
prompt_tool.parameters["properties"][parameter.name] = { prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type, "type": parameter_type,
"description": parameter.llm_description or "", "description": parameter.llm_description or '',
} }
if parameter.required: if parameter.required:
if parameter.name not in prompt_tool.parameters["required"]: if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters["required"].append(parameter.name) prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool return prompt_tool
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]: def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
""" """
Init tools Init tools
""" """
@ -268,51 +261,51 @@ class BaseAgentRunner(AppRunner):
enum = [] enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT: if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] enum = [option.value for option in parameter.options]
prompt_tool.parameters["properties"][parameter.name] = { prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type, "type": parameter_type,
"description": parameter.llm_description or "", "description": parameter.llm_description or '',
} }
if len(enum) > 0: if len(enum) > 0:
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required: if parameter.required:
if parameter.name not in prompt_tool.parameters["required"]: if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters["required"].append(parameter.name) prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool return prompt_tool
def create_agent_thought( def create_agent_thought(self, message_id: str, message: str,
self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought: ) -> MessageAgentThought:
""" """
Create agent thought Create agent thought
""" """
thought = MessageAgentThought( thought = MessageAgentThought(
message_id=message_id, message_id=message_id,
message_chain_id=None, message_chain_id=None,
thought="", thought='',
tool=tool_name, tool=tool_name,
tool_labels_str="{}", tool_labels_str='{}',
tool_meta_str="{}", tool_meta_str='{}',
tool_input=tool_input, tool_input=tool_input,
message=message, message=message,
message_token=0, message_token=0,
message_unit_price=0, message_unit_price=0,
message_price_unit=0, message_price_unit=0,
message_files=json.dumps(messages_ids) if messages_ids else "", message_files=json.dumps(messages_ids) if messages_ids else '',
answer="", answer='',
observation="", observation='',
answer_token=0, answer_token=0,
answer_unit_price=0, answer_unit_price=0,
answer_price_unit=0, answer_price_unit=0,
tokens=0, tokens=0,
total_price=0, total_price=0,
position=self.agent_thought_count + 1, position=self.agent_thought_count + 1,
currency="USD", currency='USD',
latency=0, latency=0,
created_by_role="account", created_by_role='account',
created_by=self.user_id, created_by=self.user_id,
) )
@ -325,22 +318,22 @@ class BaseAgentRunner(AppRunner):
return thought return thought
def save_agent_thought( def save_agent_thought(self,
self, agent_thought: MessageAgentThought,
agent_thought: MessageAgentThought, tool_name: str,
tool_name: str, tool_input: Union[str, dict],
tool_input: Union[str, dict], thought: str,
thought: str, observation: Union[str, dict],
observation: Union[str, dict], tool_invoke_meta: Union[str, dict],
tool_invoke_meta: Union[str, dict], answer: str,
answer: str, messages_ids: list[str],
messages_ids: list[str], llm_usage: LLMUsage = None) -> MessageAgentThought:
llm_usage: LLMUsage = None,
) -> MessageAgentThought:
""" """
Save agent thought Save agent thought
""" """
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()
if thought is not None: if thought is not None:
agent_thought.thought = thought agent_thought.thought = thought
@ -363,7 +356,7 @@ class BaseAgentRunner(AppRunner):
observation = json.dumps(observation, ensure_ascii=False) observation = json.dumps(observation, ensure_ascii=False)
except Exception as e: except Exception as e:
observation = json.dumps(observation) observation = json.dumps(observation)
agent_thought.observation = observation agent_thought.observation = observation
if answer is not None: if answer is not None:
@ -371,7 +364,7 @@ class BaseAgentRunner(AppRunner):
if messages_ids is not None and len(messages_ids) > 0: if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids) agent_thought.message_files = json.dumps(messages_ids)
if llm_usage: if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit agent_thought.message_price_unit = llm_usage.prompt_price_unit
@ -384,7 +377,7 @@ class BaseAgentRunner(AppRunner):
# check if tool labels is not empty # check if tool labels is not empty
labels = agent_thought.tool_labels or {} labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(";") if agent_thought.tool else [] tools = agent_thought.tool.split(';') if agent_thought.tool else []
for tool in tools: for tool in tools:
if not tool: if not tool:
continue continue
@ -393,7 +386,7 @@ class BaseAgentRunner(AppRunner):
if tool_label: if tool_label:
labels[tool] = tool_label.to_dict() labels[tool] = tool_label.to_dict()
else: else:
labels[tool] = {"en_US": tool, "zh_Hans": tool} labels[tool] = {'en_US': tool, 'zh_Hans': tool}
agent_thought.tool_labels_str = json.dumps(labels) agent_thought.tool_labels_str = json.dumps(labels)
@ -408,18 +401,14 @@ class BaseAgentRunner(AppRunner):
db.session.commit() db.session.commit()
db.session.close() db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables): def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
""" """
convert tool variables to db variables convert tool variables to db variables
""" """
db_variables = ( db_variables = db.session.query(ToolConversationVariables).filter(
db.session.query(ToolConversationVariables) ToolConversationVariables.conversation_id == self.message.conversation_id,
.filter( ).first()
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
@ -436,14 +425,9 @@ class BaseAgentRunner(AppRunner):
if isinstance(prompt_message, SystemPromptMessage): if isinstance(prompt_message, SystemPromptMessage):
result.append(prompt_message) result.append(prompt_message)
messages: list[Message] = ( messages: list[Message] = db.session.query(Message).filter(
db.session.query(Message) Message.conversation_id == self.message.conversation_id,
.filter( ).order_by(Message.created_at.asc()).all()
Message.conversation_id == self.message.conversation_id,
)
.order_by(Message.created_at.asc())
.all()
)
for message in messages: for message in messages:
if message.id == self.message.id: if message.id == self.message.id:
@ -455,13 +439,13 @@ class BaseAgentRunner(AppRunner):
for agent_thought in agent_thoughts: for agent_thought in agent_thoughts:
tools = agent_thought.tool tools = agent_thought.tool
if tools: if tools:
tools = tools.split(";") tools = tools.split(';')
tool_calls: list[AssistantPromptMessage.ToolCall] = [] tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = [] tool_call_response: list[ToolPromptMessage] = []
try: try:
tool_inputs = json.loads(agent_thought.tool_input) tool_inputs = json.loads(agent_thought.tool_input)
except Exception as e: except Exception as e:
tool_inputs = {tool: {} for tool in tools} tool_inputs = { tool: {} for tool in tools }
try: try:
tool_responses = json.loads(agent_thought.observation) tool_responses = json.loads(agent_thought.observation)
except Exception as e: except Exception as e:
@@ -470,33 +454,27 @@ class BaseAgentRunner(AppRunner):
                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(
                                AssistantPromptMessage.ToolCall(
                                    id=tool_call_id,
                                    type="function",
                                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                        name=tool,
                                        arguments=json.dumps(tool_inputs.get(tool, {})),
                                    ),
                                )
                            )
                            tool_call_response.append(
                                ToolPromptMessage(
                                    content=tool_responses.get(tool, agent_thought.observation),
                                    name=tool,
                                    tool_call_id=tool_call_id,
                                )
                            )

                        result.extend(
                            [
                                AssistantPromptMessage(
                                    content=agent_thought.thought,
                                    tool_calls=tool_calls,
                                ),
                                *tool_call_response,
                            ]
                        )
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
@@ -518,7 +496,10 @@ class BaseAgentRunner(AppRunner):
        file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
        if file_extra_config:
            file_objs = message_file_parser.transform_message_files(files, file_extra_config)
        else:
            file_objs = []

View File

@@ -25,19 +25,17 @@ from models.model import Message

class CotAgentRunner(BaseAgentRunner, ABC):
    _is_first_iteration = True
    _ignore_observation_providers = ["wenxin"]
    _historic_prompt_messages: list[PromptMessage] = None
    _agent_scratchpad: list[AgentScratchpadUnit] = None
    _instruction: str = None
    _query: str = None
    _prompt_messages_tools: list[PromptMessage] = None

    def run(
        self,
        message: Message,
        query: str,
        inputs: dict[str, str],
    ) -> Union[Generator, LLMResult]:
        """
        Run Cot agent application
        """
@@ -48,16 +46,17 @@ class CotAgentRunner(BaseAgentRunner, ABC):
        trace_manager = app_generate_entity.trace_manager

        # check model mode
        if "Observation" not in app_generate_entity.model_conf.stop:
            if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
                app_generate_entity.model_conf.stop.append("Observation")

        app_config = self.app_config

        # init instruction
        inputs = inputs or {}
        instruction = app_config.prompt_template.simple_prompt_template
        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@@ -66,14 +65,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
        tool_instances, self._prompt_messages_tools = self._init_prompt_tools()

        function_call_state = True
        llm_usage = {"usage": None}
        final_answer = ""

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
@@ -93,13 +94,17 @@ class CotAgentRunner(BaseAgentRunner, ABC):
            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            if iteration_step > 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
@@ -120,20 +125,21 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                raise ValueError("failed to invoke llm")

            usage_dict = {}
            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
            scratchpad = AgentScratchpadUnit(
                agent_response="",
                thought="",
                action_str="",
                observation="",
                action=None,
            )

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            for chunk in react_chunks:
                if isinstance(chunk, AgentScratchpadUnit.Action):
@@ -148,51 +154,61 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint="",
                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                    )

            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
            self._agent_scratchpad.append(scratchpad)

            # get llm usage
            if "usage" in usage_dict:
                increase_usage(llm_usage, usage_dict["usage"])
            else:
                usage_dict["usage"] = LLMUsage.empty_usage()

            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=scratchpad.action.action_name if scratchpad.action else "",
                tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                tool_invoke_meta={},
                thought=scratchpad.thought,
                observation="",
                answer=scratchpad.agent_response,
                messages_ids=[],
                llm_usage=usage_dict["usage"],
            )

            if not scratchpad.is_final():
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = ""
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    try:
                        if isinstance(scratchpad.action.action_input, dict):
                            final_answer = json.dumps(scratchpad.action.action_input)
                        elif isinstance(scratchpad.action.action_input, str):
                            final_answer = scratchpad.action.action_input
                        else:
                            final_answer = f"{scratchpad.action.action_input}"
                    except json.JSONDecodeError:
                        final_answer = f"{scratchpad.action.action_input}"
                else:
                    function_call_state = True
                    # action is tool call, invoke tool
@@ -208,18 +224,21 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                    self.save_agent_thought(
                        agent_thought=agent_thought,
                        tool_name=scratchpad.action.action_name,
                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                        thought=scratchpad.thought,
                        observation={scratchpad.action.action_name: tool_invoke_response},
                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                        answer=scratchpad.agent_response,
                        messages_ids=message_file_ids,
                        llm_usage=usage_dict["usage"],
                    )

                    self.queue_manager.publish(
                        QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                    )

            # update prompt tool message
            for prompt_tool in self._prompt_messages_tools:
@@ -231,45 +250,44 @@ class CotAgentRunner(BaseAgentRunner, ABC):
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
            ),
            system_fingerprint="",
        )

        # save agent thought
        self.save_agent_thought(
            agent_thought=agent_thought,
            tool_name="",
            tool_input={},
            tool_invoke_meta={},
            thought=final_answer,
            observation={},
            answer=final_answer,
            messages_ids=[],
        )

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def _handle_invoke_action(
        self,
        action: AgentScratchpadUnit.Action,
        tool_instances: dict[str, Tool],
        message_file_ids: list[str],
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> tuple[str, ToolInvokeMeta]:
        """
        handle invoke action
        :param action: action
@@ -308,12 +326,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
        # publish files
        for message_file_id, save_as in message_files:
            if save_as:
                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

            # publish message file
            self.queue_manager.publish(
                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
            )
            # add message file ids
            message_file_ids.append(message_file_id)
@@ -323,7 +342,10 @@ class CotAgentRunner(BaseAgentRunner, ABC):
        """
        convert dict to action
        """
        return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])

    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
        """
@@ -331,7 +353,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
        """
        for key, value in inputs.items():
            try:
                instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
            except Exception as e:
                continue
@@ -348,14 +370,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
    @abstractmethod
    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        organize prompt messages
        """

    def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
        """
        format assistant message
        """
        message = ""
        for scratchpad in agent_scratchpad:
            if scratchpad.is_final():
                message += f"Final Answer: {scratchpad.agent_response}"
@@ -368,11 +390,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
        return message

    def _organize_historic_prompt_messages(
        self, current_session_messages: list[PromptMessage] = None
    ) -> list[PromptMessage]:
        """
        organize historic prompt messages
        """
        result: list[PromptMessage] = []
        scratchpads: list[AgentScratchpadUnit] = []
@@ -383,8 +403,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                if not current_scratchpad:
                    current_scratchpad = AgentScratchpadUnit(
                        agent_response=message.content,
                        thought=message.content or "I am thinking about how to help you",
                        action_str="",
                        action=None,
                        observation=None,
                    )
@@ -393,9 +413,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                    try:
                        current_scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=message.tool_calls[0].function.name,
                            action_input=json.loads(message.tool_calls[0].function.arguments),
                        )
                        current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                    except:
                        pass
            elif isinstance(message, ToolPromptMessage):
@@ -403,19 +426,23 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                    current_scratchpad.observation = message.content
            elif isinstance(message, UserPromptMessage):
                if scratchpads:
                    result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                    scratchpads = []
                    current_scratchpad = None

                result.append(message)

        if scratchpads:
            result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))

        historic_prompts = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=current_session_messages or [],
            history_messages=result,
            memory=self.memory,
        ).get_prompt()
        return historic_prompts

View File

@@ -19,15 +19,14 @@ class CotChatAgentRunner(CotAgentRunner):
        prompt_entity = self.app_config.agent.prompt
        first_prompt = prompt_entity.first_prompt

        system_prompt = (
            first_prompt.replace("{{instruction}}", self._instruction)
            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
        )

        return SystemPromptMessage(content=system_prompt)

    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
        """
        Organize user query
        """
@@ -44,7 +43,7 @@ class CotChatAgentRunner(CotAgentRunner):
    def _organize_prompt_messages(self) -> list[PromptMessage]:
        """
        Organize
        """
        # organize system prompt
        system_message = self._organize_system_prompt()
@@ -54,7 +53,7 @@ class CotChatAgentRunner(CotAgentRunner):
        if not agent_scratchpad:
            assistant_messages = []
        else:
            assistant_message = AssistantPromptMessage(content="")
            for unit in agent_scratchpad:
                if unit.is_final():
                    assistant_message.content += f"Final Answer: {unit.agent_response}"
@@ -72,15 +71,18 @@ class CotChatAgentRunner(CotAgentRunner):
        if assistant_messages:
            # organize historic prompt messages
            historic_messages = self._organize_historic_prompt_messages(
                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
            )
            messages = [
                system_message,
                *historic_messages,
                *query_messages,
                *assistant_messages,
                UserPromptMessage(content="continue"),
            ]
        else:
            # organize historic prompt messages

View File

@@ -13,12 +13,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
        prompt_entity = self.app_config.agent.prompt
        first_prompt = prompt_entity.first_prompt

        system_prompt = (
            first_prompt.replace("{{instruction}}", self._instruction)
            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
        )

        return system_prompt

    def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@@ -48,7 +46,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
        # organize current assistant messages
        agent_scratchpad = self._agent_scratchpad
        assistant_prompt = ""
        for unit in agent_scratchpad:
            if unit.is_final():
                assistant_prompt += f"Final Answer: {unit.agent_response}"
@@ -63,10 +61,9 @@ class CotCompletionAgentRunner(CotAgentRunner):
        query_prompt = f"Question: {self._query}"

        # join all messages
        prompt = (
            system_prompt.replace("{{historic_messages}}", historic_prompt)
            .replace("{{agent_scratchpad}}", assistant_prompt)
            .replace("{{query}}", query_prompt)
        )

        return [UserPromptMessage(content=prompt)]

View File

@@ -8,7 +8,6 @@ class AgentToolEntity(BaseModel):
    """
    Agent Tool Entity.
    """

    provider_type: Literal["builtin", "api", "workflow"]
    provider_id: str
    tool_name: str
@@ -19,7 +18,6 @@ class AgentPromptEntity(BaseModel):
    """
    Agent Prompt Entity.
    """

    first_prompt: str
    next_iteration: str
@@ -33,7 +31,6 @@ class AgentScratchpadUnit(BaseModel):
        """
        Action Entity.
        """

        action_name: str
        action_input: Union[dict, str]
@@ -42,8 +39,8 @@ class AgentScratchpadUnit(BaseModel):
            Convert to dictionary.
            """
            return {
                "action": self.action_name,
                "action_input": self.action_input,
            }

    agent_response: Optional[str] = None
@@ -57,10 +54,10 @@ class AgentScratchpadUnit(BaseModel):
        Check if the scratchpad unit is final.
        """
        return self.action is None or (
            "final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
        )


class AgentEntity(BaseModel):
    """
    Agent Entity.
@@ -70,9 +67,8 @@ class AgentEntity(BaseModel):
        """
        Agent Strategy.
        """

        CHAIN_OF_THOUGHT = "chain-of-thought"
        FUNCTION_CALLING = "function-calling"

    provider: str
    model: str

View File

@@ -24,9 +24,11 @@ from models.model import Message
logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
@@ -43,17 +45,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
        # continue to run until there is not any tool call
        function_call_state = True
        llm_usage = {"usage": None}
        final_answer = ""

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
@@ -71,7 +75,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # recalc llm max tokens
@@ -91,11 +99,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ""

            # save tool call names and inputs
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None
@@ -103,22 +111,24 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                        )
                        is_first_chunk = False
                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps(
                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                            )
                        except json.JSONDecodeError as e:
                            # ensure ascii to avoid encoding error
                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
@@ -138,14 +148,16 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps(
                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                        )
                    except json.JSONDecodeError as e:
                        # ensure ascii to avoid encoding error
                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                if result.usage:
                    increase_usage(llm_usage, result.usage)
@@ -159,12 +171,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
@@ -173,29 +185,32 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )

            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
@ -203,13 +218,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
observation=None, observation=None,
answer=response, answer=response,
messages_ids=[], messages_ids=[],
llm_usage=current_llm_usage, llm_usage=current_llm_usage
) )
self.queue_manager.publish( self.queue_manager.publish(QueueAgentThoughtEvent(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER agent_thought_id=agent_thought.id
) ), PublishFrom.APPLICATION_MANAGER)
final_answer += response + "\n" final_answer += response + '\n'
# call tools # call tools
tool_responses = [] tool_responses = []
@ -220,7 +235,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
"tool_call_id": tool_call_id, "tool_call_id": tool_call_id,
"tool_call_name": tool_call_name, "tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}", "tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(), "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
} }
else: else:
# invoke tool # invoke tool
@ -240,49 +255,50 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as) self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file # publish message file
self.queue_manager.publish( self.queue_manager.publish(QueueMessageFileEvent(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER message_file_id=message_file_id
) ), PublishFrom.APPLICATION_MANAGER)
# add message file ids # add message file ids
message_file_ids.append(message_file_id) message_file_ids.append(message_file_id)
tool_response = { tool_response = {
"tool_call_id": tool_call_id, "tool_call_id": tool_call_id,
"tool_call_name": tool_call_name, "tool_call_name": tool_call_name,
"tool_response": tool_invoke_response, "tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict(), "meta": tool_invoke_meta.to_dict()
} }
tool_responses.append(tool_response) tool_responses.append(tool_response)
if tool_response["tool_response"] is not None: if tool_response['tool_response'] is not None:
self._current_thoughts.append( self._current_thoughts.append(
ToolPromptMessage( ToolPromptMessage(
content=tool_response["tool_response"], content=tool_response['tool_response'],
tool_call_id=tool_call_id, tool_call_id=tool_call_id,
name=tool_call_name, name=tool_call_name,
) )
) )
if len(tool_responses) > 0: if len(tool_responses) > 0:
# save agent thought # save agent thought
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name=None, tool_name=None,
tool_input=None, tool_input=None,
thought=None, thought=None,
tool_invoke_meta={ tool_invoke_meta={
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses tool_response['tool_call_name']: tool_response['meta']
for tool_response in tool_responses
}, },
observation={ observation={
tool_response["tool_call_name"]: tool_response["tool_response"] tool_response['tool_call_name']: tool_response['tool_response']
for tool_response in tool_responses for tool_response in tool_responses
}, },
answer=None, answer=None,
messages_ids=message_file_ids, messages_ids=message_file_ids
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
) )
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool # update prompt tool
for prompt_tool in prompt_messages_tools: for prompt_tool in prompt_messages_tools:
@ -292,18 +308,15 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.update_db_variables(self.variables_pool, self.db_variables_pool) self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event # publish end event
self.queue_manager.publish( self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
QueueMessageEndEvent( model=model_instance.model,
llm_result=LLMResult( prompt_messages=prompt_messages,
model=model_instance.model, message=AssistantPromptMessage(
prompt_messages=prompt_messages, content=final_answer
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
), ),
PublishFrom.APPLICATION_MANAGER, usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
) system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool: def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
""" """
@ -312,7 +325,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if llm_result_chunk.delta.message.tool_calls: if llm_result_chunk.delta.message.tool_calls:
return True return True
return False return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
""" """
Check if there is any blocking tool call in llm result Check if there is any blocking tool call in llm result
@ -321,9 +334,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True return True
return False return False
def extract_tool_calls( def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
""" """
Extract tool calls from llm result chunk Extract tool calls from llm result chunk
@ -333,19 +344,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = [] tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls: for prompt_message in llm_result_chunk.delta.message.tool_calls:
args = {} args = {}
if prompt_message.function.arguments != "": if prompt_message.function.arguments != '':
args = json.loads(prompt_message.function.arguments) args = json.loads(prompt_message.function.arguments)
tool_calls.append( tool_calls.append((
( prompt_message.id,
prompt_message.id, prompt_message.function.name,
prompt_message.function.name, args,
args, ))
)
)
return tool_calls return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
""" """
Extract blocking tool calls from llm result Extract blocking tool calls from llm result
@ -356,22 +365,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = [] tool_calls = []
for prompt_message in llm_result.message.tool_calls: for prompt_message in llm_result.message.tool_calls:
args = {} args = {}
if prompt_message.function.arguments != "": if prompt_message.function.arguments != '':
args = json.loads(prompt_message.function.arguments) args = json.loads(prompt_message.function.arguments)
tool_calls.append( tool_calls.append((
( prompt_message.id,
prompt_message.id, prompt_message.function.name,
prompt_message.function.name, args,
args, ))
)
)
return tool_calls return tool_calls
def _init_system_message( def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
""" """
Initialize system message Initialize system message
""" """
@ -379,13 +384,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return [ return [
SystemPromptMessage(content=prompt_template), SystemPromptMessage(content=prompt_template),
] ]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template: if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template)) prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]: def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
""" """
Organize user query Organize user query
""" """
@ -399,7 +404,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
prompt_messages.append(UserPromptMessage(content=query)) prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
""" """
As for now, gpt supports both fc and vision at the first iteration. As for now, gpt supports both fc and vision at the first iteration.
@ -410,21 +415,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage): if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list): if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join( prompt_message.content = '\n'.join([
[ content.data if content.type == PromptMessageContentType.TEXT else
content.data '[image]' if content.type == PromptMessageContentType.IMAGE else
if content.type == PromptMessageContentType.TEXT '[file]'
else "[image]" for content in prompt_message.content
if content.type == PromptMessageContentType.IMAGE ])
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages return prompt_messages
def _organize_prompt_messages(self): def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or "" prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages) self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, []) query_prompt_messages = self._organize_user_query(self.query, [])
@ -432,10 +433,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model_config=self.model_config, model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts], prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages, history_messages=self.history_prompt_messages,
memory=self.memory, memory=self.memory
).get_prompt() ).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts] prompt_messages = [
*self.history_prompt_messages,
*query_prompt_messages,
*self._current_thoughts
]
if len(self._current_thoughts) != 0: if len(self._current_thoughts) != 0:
# clear messages after the first iteration # clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages) prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
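The runner above turns each streamed tool-call delta into an `(id, name, parsed-args)` tuple before dispatching tools. A minimal standalone sketch of that extraction step; the plain dicts here are hypothetical stand-ins for the runtime's `AssistantPromptMessage.ToolCall` objects:

```python
import json

def extract_tool_calls(delta_tool_calls):
    """Collect (id, name, parsed-args) tuples from a chunk's tool-call deltas.

    Mirrors the extract_tool_calls hunk above: empty argument strings
    become an empty dict instead of being passed to json.loads.
    """
    tool_calls = []
    for call in delta_tool_calls:
        args = {}
        if call["function"]["arguments"] != "":
            args = json.loads(call["function"]["arguments"])
        tool_calls.append((call["id"], call["function"]["name"], args))
    return tool_calls

calls = extract_tool_calls([
    {"id": "call_1", "function": {"name": "web_search", "arguments": '{"query": "dify"}'}},
    {"id": "call_2", "function": {"name": "noop", "arguments": ""}},
])
```

Guarding on the empty string matters because `json.loads("")` raises, and providers commonly stream the arguments field incrementally.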

View File

@@ -9,9 +9,8 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser:
    @classmethod
    def handle_react_stream_output(
        cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
    ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
        def parse_action(json_str):
            try:
                action = json.loads(json_str)
@@ -23,7 +22,7 @@ class CotAgentOutputParser:
                    action = action[0]

                for key, value in action.items():
                    if "input" in key.lower():
                        action_input = value
                    else:
                        action_name = value
@@ -34,37 +33,37 @@ class CotAgentOutputParser:
                        action_input=action_input,
                    )
                else:
                    return json_str or ""
            except:
                return json_str or ""

        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
            code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
            if not code_blocks:
                return
            for block in code_blocks:
                json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
                yield parse_action(json_text)

        code_block_cache = ""
        code_block_delimiter_count = 0
        in_code_block = False
        json_cache = ""
        json_quote_count = 0
        in_json = False
        got_json = False

        action_cache = ""
        action_str = "action:"
        action_idx = 0

        thought_cache = ""
        thought_str = "thought:"
        thought_idx = 0

        for response in llm_response:
            if response.delta.usage:
                usage_dict["usage"] = response.delta.usage
            response = response.delta.message.content
            if not isinstance(response, str):
                continue
@@ -73,24 +72,24 @@ class CotAgentOutputParser:
            index = 0
            while index < len(response):
                steps = 1
                delta = response[index : index + steps]
                last_character = response[index - 1] if index > 0 else ""

                if delta == "`":
                    code_block_cache += delta
                    code_block_delimiter_count += 1
                else:
                    if not in_code_block:
                        if code_block_delimiter_count > 0:
                            yield code_block_cache
                            code_block_cache = ""
                    else:
                        code_block_cache += delta
                    code_block_delimiter_count = 0

                if not in_code_block and not in_json:
                    if delta.lower() == action_str[action_idx] and action_idx == 0:
                        if last_character not in {"\n", " ", ""}:
                            index += steps
                            yield delta
                            continue
@@ -98,7 +97,7 @@ class CotAgentOutputParser:
                        action_cache += delta
                        action_idx += 1
                        if action_idx == len(action_str):
                            action_cache = ""
                            action_idx = 0
                        index += steps
                        continue
@@ -106,18 +105,18 @@ class CotAgentOutputParser:
                        action_cache += delta
                        action_idx += 1
                        if action_idx == len(action_str):
                            action_cache = ""
                            action_idx = 0
                        index += steps
                        continue
                    else:
                        if action_cache:
                            yield action_cache
                            action_cache = ""
                            action_idx = 0

                    if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
                        if last_character not in {"\n", " ", ""}:
                            index += steps
                            yield delta
                            continue
@@ -125,7 +124,7 @@ class CotAgentOutputParser:
                        thought_cache += delta
                        thought_idx += 1
                        if thought_idx == len(thought_str):
                            thought_cache = ""
                            thought_idx = 0
                        index += steps
                        continue
@@ -133,31 +132,31 @@ class CotAgentOutputParser:
                        thought_cache += delta
                        thought_idx += 1
                        if thought_idx == len(thought_str):
                            thought_cache = ""
                            thought_idx = 0
                        index += steps
                        continue
                    else:
                        if thought_cache:
                            yield thought_cache
                            thought_cache = ""
                            thought_idx = 0

                if code_block_delimiter_count == 3:
                    if in_code_block:
                        yield from extra_json_from_code_block(code_block_cache)
                        code_block_cache = ""

                    in_code_block = not in_code_block
                    code_block_delimiter_count = 0

                if not in_code_block:
                    # handle single json
                    if delta == "{":
                        json_quote_count += 1
                        in_json = True
                        json_cache += delta
                    elif delta == "}":
                        json_cache += delta
                        if json_quote_count > 0:
                            json_quote_count -= 1
@@ -173,12 +172,12 @@ class CotAgentOutputParser:
                    if got_json:
                        got_json = False
                        yield parse_action(json_cache)
                        json_cache = ""
                        json_quote_count = 0
                        in_json = False

                if not in_code_block and not in_json:
                    yield delta.replace("`", "")

                index += steps
@@ -187,3 +186,4 @@ class CotAgentOutputParser:
        if json_cache:
            yield parse_action(json_cache)
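The streaming parser above scans character by character so it can emit thoughts as they arrive. When the whole response is already in hand, the same action extraction reduces to a short batch version; this sketch builds the fence delimiter programmatically and is a simplification, not the parser the codebase ships:

```python
import json
import re

FENCE = "`" * 3  # three backticks, assembled to keep them out of the literal

def parse_actions(text):
    """Pull action JSON blobs out of fenced code blocks in a ReAct reply.

    Non-streaming counterpart of handle_react_stream_output's
    extra_json_from_code_block helper.
    """
    results = []
    for block in re.findall(FENCE + "(.*?)" + FENCE, text, re.DOTALL):
        # drop a leading language tag such as "json"
        json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
        try:
            results.append(json.loads(json_text))
        except json.JSONDecodeError:
            continue
    return results

response = (
    "Thought: search first\n"
    "Action:" + FENCE + 'json\n{"action": "web_search", "action_input": {"q": "dify"}}\n' + FENCE
)
actions = parse_actions(response)
```

The `re.DOTALL` flag lets `.` cross newlines inside the fence, and the `re.MULTILINE` substitution strips only a leading language tag line, leaving the JSON body intact.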

View File

@@ -41,8 +41,7 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
{{historic_messages}}
Question: {{query}}
{{agent_scratchpad}}
Thought:"""  # noqa: E501

ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
Thought:"""
@@ -87,20 +86,19 @@ Action:
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
"""  # noqa: E501

ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""

REACT_PROMPT_TEMPLATES = {
    "english": {
        "chat": {
            "prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
            "agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
        },
        "completion": {
            "prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
            "agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
        },
    }
}
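The templates above carry `{{name}}` slots such as `{{query}}` and `{{agent_scratchpad}}` that are filled in at prompt-assembly time. A plain-string sketch of that substitution step; the actual renderer in the codebase may differ:

```python
def fill_template(template, values):
    """Replace every {{name}} slot in a ReAct-style prompt template.

    Simple sequential replacement; slots with no entry in `values`
    are left untouched.
    """
    for name, value in values.items():
        template = template.replace("{{" + name + "}}", value)
    return template

prompt = fill_template(
    "Question: {{query}}\n{{agent_scratchpad}}\nThought:",
    {"query": "What is 2+2?", "agent_scratchpad": "Observation: none"},
)
```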

View File

@@ -26,24 +26,34 @@ class BaseAppConfigManager:
        config_dict = dict(config_dict.items())

        additional_features = AppAdditionalFeatures()
        additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)

        additional_features.file_upload = FileUploadConfigManager.convert(
            config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT}
        )

        additional_features.opening_statement, additional_features.suggested_questions = (
            OpeningStatementConfigManager.convert(config=config_dict)
        )

        additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
            config=config_dict
        )

        additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)

        additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)

        additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)

        return additional_features

View File

@@ -7,24 +7,25 @@ from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
    @classmethod
    def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
        sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
        if not sensitive_word_avoidance_dict:
            return None

        if sensitive_word_avoidance_dict.get("enabled"):
            return SensitiveWordAvoidanceEntity(
                type=sensitive_word_avoidance_dict.get("type"),
                config=sensitive_word_avoidance_dict.get("config"),
            )
        else:
            return None

    @classmethod
    def validate_and_set_defaults(
        cls, tenant_id, config: dict, only_structure_validate: bool = False
    ) -> tuple[dict, list[str]]:
        if not config.get("sensitive_word_avoidance"):
            config["sensitive_word_avoidance"] = {"enabled": False}

        if not isinstance(config["sensitive_word_avoidance"], dict):
            raise ValueError("sensitive_word_avoidance must be of dict type")
@@ -40,6 +41,10 @@ class SensitiveWordAvoidanceConfigManager:
            typ = config["sensitive_word_avoidance"]["type"]
            sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]

            ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)

        return config, ["sensitive_word_avoidance"]

View File

@@ -12,70 +12,67 @@ class AgentConfigManager:
        :param config: model config args
        """
        if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
            agent_dict = config.get("agent_mode", {})
            agent_strategy = agent_dict.get("strategy", "cot")

            if agent_strategy == "function_call":
                strategy = AgentEntity.Strategy.FUNCTION_CALLING
            elif agent_strategy in {"cot", "react"}:
                strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
            else:
                # old configs, try to detect default strategy
                if config["model"]["provider"] == "openai":
                    strategy = AgentEntity.Strategy.FUNCTION_CALLING
                else:
                    strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

            agent_tools = []
            for tool in agent_dict.get("tools", []):
                keys = tool.keys()
                if len(keys) >= 4:
                    if "enabled" not in tool or not tool["enabled"]:
                        continue

                    agent_tool_properties = {
                        "provider_type": tool["provider_type"],
                        "provider_id": tool["provider_id"],
                        "tool_name": tool["tool_name"],
                        "tool_parameters": tool.get("tool_parameters", {}),
                    }

                    agent_tools.append(AgentToolEntity(**agent_tool_properties))

            if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
                "react_router",
                "router",
            }:
                agent_prompt = agent_dict.get("prompt", None) or {}
                # check model mode
                model_mode = config.get("model", {}).get("mode", "completion")
                if model_mode == "completion":
                    agent_prompt_entity = AgentPromptEntity(
                        first_prompt=agent_prompt.get(
                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
                        ),
                        next_iteration=agent_prompt.get(
                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
                        ),
                    )
                else:
                    agent_prompt_entity = AgentPromptEntity(
                        first_prompt=agent_prompt.get(
                            "first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
                        ),
                        next_iteration=agent_prompt.get(
                            "next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
                        ),
                    )

                return AgentEntity(
                    provider=config["model"]["provider"],
                    model=config["model"]["name"],
                    strategy=strategy,
                    prompt=agent_prompt_entity,
                    tools=agent_tools,
                    max_iteration=agent_dict.get("max_iteration", 5),
                )

        return None
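The strategy-detection branch above maps a stored `agent_mode` config to either function calling or chain of thought, falling back to a provider heuristic for old configs. The same decision logic, reduced to plain strings in place of the `AgentEntity.Strategy` enum members:

```python
def detect_strategy(agent_mode, model_provider):
    """Map an agent_mode config dict to a strategy name.

    Mirrors the branch above: explicit "function_call" wins, "cot"/"react"
    mean chain of thought, and legacy configs fall back to the provider.
    """
    strategy = agent_mode.get("strategy", "cot")
    if strategy == "function_call":
        return "function_calling"
    if strategy in {"cot", "react"}:
        return "chain_of_thought"
    # old configs: infer a default from the model provider
    return "function_calling" if model_provider == "openai" else "chain_of_thought"

s1 = detect_strategy({"strategy": "react"}, "openai")
s2 = detect_strategy({"strategy": "legacy"}, "openai")
```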

View File

@@ -15,38 +15,39 @@ class DatasetConfigManager:
        :param config: model config args
        """
        dataset_ids = []
        if "datasets" in config.get("dataset_configs", {}):
            datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})

            for dataset in datasets.get("datasets", []):
                keys = list(dataset.keys())
                if len(keys) == 0 or keys[0] != "dataset":
                    continue

                dataset = dataset["dataset"]

                if "enabled" not in dataset or not dataset["enabled"]:
                    continue

                dataset_id = dataset.get("id", None)
                if dataset_id:
                    dataset_ids.append(dataset_id)

        if (
            "agent_mode" in config
            and config["agent_mode"]
            and "enabled" in config["agent_mode"]
            and config["agent_mode"]["enabled"]
        ):
            agent_dict = config.get("agent_mode", {})

            for tool in agent_dict.get("tools", []):
                keys = tool.keys()
                if len(keys) == 1:
                    # old standard
                    key = list(tool.keys())[0]

                    if key != "dataset":
                        continue

                    tool_item = tool[key]
@@ -54,28 +55,30 @@ class DatasetConfigManager:
                    if "enabled" not in tool_item or not tool_item["enabled"]:
                        continue

                    dataset_id = tool_item["id"]

                    dataset_ids.append(dataset_id)

        if len(dataset_ids) == 0:
            return None

        # dataset configs
        if "dataset_configs" in config and config.get("dataset_configs"):
            dataset_configs = config.get("dataset_configs")
        else:
            dataset_configs = {"retrieval_model": "multiple"}
        query_variable = config.get("dataset_query_variable")

        if dataset_configs["retrieval_model"] == "single":
            return DatasetEntity(
                dataset_ids=dataset_ids,
                retrieve_config=DatasetRetrieveConfigEntity(
                    query_variable=query_variable,
                    retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                        dataset_configs["retrieval_model"]
                    ),
                ),
            )
        else:
            return DatasetEntity(
@@ -83,15 +86,15 @@ class DatasetConfigManager:
                retrieve_config=DatasetRetrieveConfigEntity(
                    query_variable=query_variable,
                    retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                        dataset_configs["retrieval_model"]
                    ),
                    top_k=dataset_configs.get("top_k", 4),
                    score_threshold=dataset_configs.get("score_threshold"),
                    reranking_model=dataset_configs.get("reranking_model"),
                    weights=dataset_configs.get("weights"),
                    reranking_enabled=dataset_configs.get("reranking_enabled", True),
                    rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
                ),
            )

    @classmethod
@@ -108,10 +111,13 @@ class DatasetConfigManager:
        # dataset_configs
        if not config.get("dataset_configs"):
            config["dataset_configs"] = {"retrieval_model": "single"}

        if not config["dataset_configs"].get("datasets"):
            config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}

        if not isinstance(config["dataset_configs"], dict):
            raise ValueError("dataset_configs must be of object type")
@@ -119,9 +125,8 @@ class DatasetConfigManager:
        if not isinstance(config["dataset_configs"], dict):
            raise ValueError("dataset_configs must be of object type")

        need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {} and config["dataset_configs"].get("datasets", {}).get("datasets"))
).get("datasets")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION: if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion # Only check when mode is completion
@ -143,7 +148,10 @@ class DatasetConfigManager:
""" """
# Extract dataset config for legacy compatibility # Extract dataset config for legacy compatibility
if not config.get("agent_mode"): if not config.get("agent_mode"):
config["agent_mode"] = {"enabled": False, "tools": []} config["agent_mode"] = {
"enabled": False,
"tools": []
}
if not isinstance(config["agent_mode"], dict): if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type") raise ValueError("agent_mode must be of object type")
@ -167,7 +175,7 @@ class DatasetConfigManager:
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
has_datasets = False has_datasets = False
if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}: if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]:
for tool in config["agent_mode"]["tools"]: for tool in config["agent_mode"]["tools"]:
key = list(tool.keys())[0] key = list(tool.keys())[0]
if key == "dataset": if key == "dataset":
@ -180,7 +188,7 @@ class DatasetConfigManager:
if not isinstance(tool_item["enabled"], bool): if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type") raise ValueError("enabled in agent_mode.tools must be of boolean type")
if "id" not in tool_item: if 'id' not in tool_item:
raise ValueError("id is required in dataset") raise ValueError("id is required in dataset")
try: try:

View File

@@ -11,7 +11,9 @@ from core.provider_manager import ProviderManager

class ModelConfigConverter:
    @classmethod
    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
        """
        Convert app model config dict to entity.
        :param app_config: app config
@@ -23,7 +25,9 @@ class ModelConfigConverter:
        provider_manager = ProviderManager()
        provider_model_bundle = provider_manager.get_provider_model_bundle(
            tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
        )

        provider_name = provider_model_bundle.configuration.provider.provider
@@ -34,7 +38,8 @@ class ModelConfigConverter:
        # check model credentials
        model_credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=ModelType.LLM, model=model_config.model
        )

        if model_credentials is None:
@@ -46,7 +51,8 @@ class ModelConfigConverter:
        if not skip_check:
            # check model
            provider_model = provider_model_bundle.configuration.get_provider_model(
                model=model_config.model, model_type=ModelType.LLM
            )

            if provider_model is None:
@@ -63,18 +69,24 @@ class ModelConfigConverter:
        # model config
        completion_params = model_config.parameters
        stop = []
        if "stop" in completion_params:
            stop = completion_params["stop"]
            del completion_params["stop"]

        # get model mode
        model_mode = model_config.mode
        if not model_mode:
            mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
            model_mode = mode_enum.value

        model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)

        if not skip_check and not model_schema:
            raise ValueError(f"Model {model_name} not exist.")

View File

@@ -13,23 +13,23 @@ class ModelConfigManager:
        :param config: model config args
        """
        # model config
        model_config = config.get("model")

        if not model_config:
            raise ValueError("model is required")

        completion_params = model_config.get("completion_params")
        stop = []
        if "stop" in completion_params:
            stop = completion_params["stop"]
            del completion_params["stop"]

        # get model mode
        model_mode = model_config.get("mode")

        return ModelConfigEntity(
            provider=config["model"]["provider"],
            model=config["model"]["name"],
            mode=model_mode,
            parameters=completion_params,
            stop=stop,
@@ -43,7 +43,7 @@ class ModelConfigManager:
        :param tenant_id: tenant id
        :param config: app model config args
        """
        if "model" not in config:
            raise ValueError("model is required")

        if not isinstance(config["model"], dict):
@@ -52,16 +52,17 @@ class ModelConfigManager:
        # model.provider
        provider_entities = model_provider_factory.get_providers()
        model_provider_names = [provider.provider for provider in provider_entities]
        if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
            raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

        # model.name
        if "name" not in config["model"]:
            raise ValueError("model.name is required")

        provider_manager = ProviderManager()
        models = provider_manager.get_configurations(tenant_id).get_models(
            provider=config["model"]["provider"], model_type=ModelType.LLM
        )

        if not models:
@@ -79,12 +80,12 @@ class ModelConfigManager:
        # model.mode
        if model_mode:
            config["model"]["mode"] = model_mode
        else:
            config["model"]["mode"] = "completion"

        # model.completion_params
        if "completion_params" not in config["model"]:
            raise ValueError("model.completion_params is required")

        config["model"]["completion_params"] = cls.validate_model_completion_params(
@@ -100,7 +101,7 @@ class ModelConfigManager:
            raise ValueError("model.completion_params must be of object type")

        # stop
        if "stop" not in cp:
            cp["stop"] = []
        elif not isinstance(cp["stop"], list):
            raise ValueError("stop in model.completion_params must be of list type")

View File

@@ -14,33 +14,39 @@ class PromptTemplateConfigManager:
        if not config.get("prompt_type"):
            raise ValueError("prompt_type is required")

        prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
        if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
            simple_prompt_template = config.get("pre_prompt", "")
            return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
        else:
            advanced_chat_prompt_template = None
            chat_prompt_config = config.get("chat_prompt_config", {})
            if chat_prompt_config:
                chat_prompt_messages = []
                for message in chat_prompt_config.get("prompt", []):
                    chat_prompt_messages.append(
                        {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
                    )

                advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)

            advanced_completion_prompt_template = None
            completion_prompt_config = config.get("completion_prompt_config", {})
            if completion_prompt_config:
                completion_prompt_template_params = {
                    "prompt": completion_prompt_config["prompt"]["text"],
                }

                if "conversation_histories_role" in completion_prompt_config:
                    completion_prompt_template_params["role_prefix"] = {
                        "user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
                        "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
                    }

                advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@@ -50,7 +56,7 @@ class PromptTemplateConfigManager:
            return PromptTemplateEntity(
                prompt_type=prompt_type,
                advanced_chat_prompt_template=advanced_chat_prompt_template,
                advanced_completion_prompt_template=advanced_completion_prompt_template,
            )

    @classmethod
@@ -66,7 +72,7 @@ class PromptTemplateConfigManager:
            config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value

        prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
        if config["prompt_type"] not in prompt_type_vals:
            raise ValueError(f"prompt_type must be in {prompt_type_vals}")

        # chat_prompt_config
@@ -83,28 +89,27 @@ class PromptTemplateConfigManager:
        if not isinstance(config["completion_prompt_config"], dict):
            raise ValueError("completion_prompt_config must be of object type")

        if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
            if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
                raise ValueError(
                    "chat_prompt_config or completion_prompt_config is required when prompt_type is advanced"
                )

            model_mode_vals = [mode.value for mode in ModelMode]
            if config["model"]["mode"] not in model_mode_vals:
                raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")

            if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
                user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
                assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]

                if not user_prefix:
                    config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"

                if not assistant_prefix:
                    config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"

            if config["model"]["mode"] == ModelMode.CHAT.value:
                prompt_list = config["chat_prompt_config"]["prompt"]

                if len(prompt_list) > 10:
                    raise ValueError("prompt messages must be less than 10")

View File

@@ -16,49 +16,51 @@ class BasicVariablesConfigManager:
        variable_entities = []

        # old external_data_tools
        external_data_tools = config.get("external_data_tools", [])
        for external_data_tool in external_data_tools:
            if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
                continue

            external_data_variables.append(
                ExternalDataVariableEntity(
                    variable=external_data_tool["variable"],
                    type=external_data_tool["type"],
                    config=external_data_tool["config"],
                )
            )

        # variables and external_data_tools
        for variables in config.get("user_input_form", []):
            variable_type = list(variables.keys())[0]
            if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
                variable = variables[variable_type]
                if "config" not in variable:
                    continue

                external_data_variables.append(
                    ExternalDataVariableEntity(
                        variable=variable["variable"], type=variable["type"], config=variable["config"]
                    )
                )
            elif variable_type in {
                VariableEntityType.TEXT_INPUT,
                VariableEntityType.PARAGRAPH,
                VariableEntityType.NUMBER,
                VariableEntityType.SELECT,
            }:
                variable = variables[variable_type]
                variable_entities.append(
                    VariableEntity(
                        type=variable_type,
                        variable=variable.get("variable"),
                        description=variable.get("description"),
                        label=variable.get("label"),
                        required=variable.get("required", False),
                        max_length=variable.get("max_length"),
                        options=variable.get("options"),
                        default=variable.get("default"),
                    )
                )
@@ -97,17 +99,17 @@ class BasicVariablesConfigManager:
        variables = []
        for item in config["user_input_form"]:
            key = list(item.keys())[0]
            if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}:
                raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")

            form_item = item[key]
            if "label" not in form_item:
                raise ValueError("label is required in user_input_form")

            if not isinstance(form_item["label"], str):
                raise ValueError("label in user_input_form must be of string type")

            if "variable" not in form_item:
                raise ValueError("variable is required in user_input_form")

            if not isinstance(form_item["variable"], str):
@@ -115,24 +117,26 @@ class BasicVariablesConfigManager:
            pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
            if pattern.match(form_item["variable"]) is None:
                raise ValueError("variable in user_input_form must be a string, and cannot start with a number")

            variables.append(form_item["variable"])

            if "required" not in form_item or not form_item["required"]:
                form_item["required"] = False

            if not isinstance(form_item["required"], bool):
                raise ValueError("required in user_input_form must be of boolean type")

            if key == "select":
                if "options" not in form_item or not form_item["options"]:
                    form_item["options"] = []

                if not isinstance(form_item["options"], list):
                    raise ValueError("options in user_input_form must be a list of strings")

                if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
                    raise ValueError("default value in user_input_form must be in the options list")

        return config, ["user_input_form"]
@@ -164,6 +168,10 @@ class BasicVariablesConfigManager:
            typ = tool["type"]
            config = tool["config"]

            ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)

        return config, ["external_data_tools"]

View File

@@ -12,7 +12,6 @@ class ModelConfigEntity(BaseModel):
    """
    Model Config Entity.
    """

    provider: str
    model: str
    mode: Optional[str] = None
@@ -24,7 +23,6 @@ class AdvancedChatMessageEntity(BaseModel):
    """
    Advanced Chat Message Entity.
    """

    text: str
    role: PromptMessageRole
@@ -33,7 +31,6 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
    """
    Advanced Chat Prompt Template Entity.
    """

    messages: list[AdvancedChatMessageEntity]
@@ -46,7 +43,6 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
        """
        Role Prefix Entity.
        """

        user: str
        assistant: str
@@ -64,12 +60,11 @@ class PromptTemplateEntity(BaseModel):
        Prompt Type.
        'simple', 'advanced'
        """

        SIMPLE = "simple"
        ADVANCED = "advanced"

        @classmethod
        def value_of(cls, value: str) -> "PromptType":
            """
            Get value of given mode.
@@ -79,7 +74,7 @@ class PromptTemplateEntity(BaseModel):
            for mode in cls:
                if mode.value == value:
                    return mode

            raise ValueError(f"invalid prompt type value {value}")

    prompt_type: PromptType
    simple_prompt_template: Optional[str] = None
@@ -92,7 +87,7 @@ class VariableEntityType(str, Enum):
    SELECT = "select"
    PARAGRAPH = "paragraph"
    NUMBER = "number"
    EXTERNAL_DATA_TOOL = "external_data_tool"


class VariableEntity(BaseModel):
@@ -115,7 +110,6 @@ class ExternalDataVariableEntity(BaseModel):
    """
    External Data Variable Entity.
    """

    variable: str
    type: str
    config: dict[str, Any] = {}
@@ -131,12 +125,11 @@ class DatasetRetrieveConfigEntity(BaseModel):
        Dataset Retrieve Strategy.
        'single' or 'multiple'
        """

        SINGLE = "single"
        MULTIPLE = "multiple"

        @classmethod
        def value_of(cls, value: str) -> "RetrieveStrategy":
            """
            Get value of given mode.
@@ -146,24 +139,25 @@ class DatasetRetrieveConfigEntity(BaseModel):
            for mode in cls:
                if mode.value == value:
                    return mode

            raise ValueError(f"invalid retrieve strategy value {value}")

    query_variable: Optional[str] = None  # Only when app mode is completion
    retrieve_strategy: RetrieveStrategy
    top_k: Optional[int] = None
    score_threshold: Optional[float] = 0.0
    rerank_mode: Optional[str] = "reranking_model"
    reranking_model: Optional[dict] = None
    weights: Optional[dict] = None
    reranking_enabled: Optional[bool] = True


class DatasetEntity(BaseModel):
    """
    Dataset Config Entity.
    """

    dataset_ids: list[str]
    retrieve_config: DatasetRetrieveConfigEntity
@@ -172,7 +166,6 @@ class SensitiveWordAvoidanceEntity(BaseModel):
    """
    Sensitive Word Avoidance Entity.
    """

    type: str
    config: dict[str, Any] = {}
@@ -181,7 +174,6 @@ class TextToSpeechEntity(BaseModel):
    """
    Sensitive Word Avoidance Entity.
    """

    enabled: bool
    voice: Optional[str] = None
    language: Optional[str] = None
@@ -191,11 +183,12 @@ class TracingConfigEntity(BaseModel):
    """
    Tracing Config Entity.
    """

    enabled: bool
    tracing_provider: str


class AppAdditionalFeatures(BaseModel):
    file_upload: Optional[FileExtraConfig] = None
    opening_statement: Optional[str] = None
@@ -207,12 +200,10 @@ class AppAdditionalFeatures(BaseModel):
    text_to_speech: Optional[TextToSpeechEntity] = None
    trace_config: Optional[TracingConfigEntity] = None


class AppConfig(BaseModel):
    """
    Application Config Entity.
    """

    tenant_id: str
    app_id: str
    app_mode: AppMode
@@ -225,17 +216,15 @@ class EasyUIBasedAppModelConfigFrom(Enum):
    """
    App Model Config From.
    """

    ARGS = "args"
    APP_LATEST_CONFIG = "app-latest-config"
    CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"


class EasyUIBasedAppConfig(AppConfig):
    """
    Easy UI Based App Config Entity.
    """

    app_model_config_from: EasyUIBasedAppModelConfigFrom
    app_model_config_id: str
    app_model_config_dict: dict
@@ -249,5 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
    """
    Workflow UI Based App Config Entity.
    """

    workflow_id: str

View File

@@ -13,19 +13,21 @@ class FileUploadConfigManager:
         :param config: model config args
         :param is_vision: if True, the feature is vision feature
         """
-        file_upload_dict = config.get("file_upload")
+        file_upload_dict = config.get('file_upload')
         if file_upload_dict:
-            if file_upload_dict.get("image"):
-                if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
+            if file_upload_dict.get('image'):
+                if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
                     image_config = {
-                        "number_limits": file_upload_dict["image"]["number_limits"],
-                        "transfer_methods": file_upload_dict["image"]["transfer_methods"],
+                        'number_limits': file_upload_dict['image']['number_limits'],
+                        'transfer_methods': file_upload_dict['image']['transfer_methods']
                     }

                     if is_vision:
-                        image_config["detail"] = file_upload_dict["image"]["detail"]
+                        image_config['detail'] = file_upload_dict['image']['detail']

-                    return FileExtraConfig(image_config=image_config)
+                    return FileExtraConfig(
+                        image_config=image_config
+                    )

         return None
@@ -47,21 +49,21 @@ class FileUploadConfigManager:
         if not config["file_upload"].get("image"):
             config["file_upload"]["image"] = {"enabled": False}

-        if config["file_upload"]["image"]["enabled"]:
-            number_limits = config["file_upload"]["image"]["number_limits"]
+        if config['file_upload']['image']['enabled']:
+            number_limits = config['file_upload']['image']['number_limits']
             if number_limits < 1 or number_limits > 6:
                 raise ValueError("number_limits must be in [1, 6]")

             if is_vision:
-                detail = config["file_upload"]["image"]["detail"]
-                if detail not in {"high", "low"}:
+                detail = config['file_upload']['image']['detail']
+                if detail not in ['high', 'low']:
                     raise ValueError("detail must be in ['high', 'low']")

-            transfer_methods = config["file_upload"]["image"]["transfer_methods"]
+            transfer_methods = config['file_upload']['image']['transfer_methods']
             if not isinstance(transfer_methods, list):
                 raise ValueError("transfer_methods must be of list type")
             for method in transfer_methods:
-                if method not in {"remote_url", "local_file"}:
+                if method not in ['remote_url', 'local_file']:
                     raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")

         return config, ["file_upload"]
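The validation rules visible in this hunk (image count limited to 1–6, `detail` restricted to high/low, transfer methods restricted to remote_url/local_file) can be exercised in isolation. A minimal standalone sketch; the `validate_image_config` helper name is mine, not part of the codebase:

```python
def validate_image_config(config: dict, is_vision: bool = True) -> dict:
    """Apply the same defaulting and checks that the diff above performs."""
    image = config.setdefault("file_upload", {}).setdefault("image", {"enabled": False})
    if image.get("enabled"):
        number_limits = image["number_limits"]
        if number_limits < 1 or number_limits > 6:
            raise ValueError("number_limits must be in [1, 6]")
        if is_vision and image["detail"] not in {"high", "low"}:
            raise ValueError("detail must be in ['high', 'low']")
        transfer_methods = image["transfer_methods"]
        if not isinstance(transfer_methods, list):
            raise ValueError("transfer_methods must be of list type")
        for method in transfer_methods:
            if method not in {"remote_url", "local_file"}:
                raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
    return config

# A config missing the feature entirely is defaulted rather than rejected:
ok = validate_image_config({"file_upload": {"image": {
    "enabled": True, "number_limits": 3, "detail": "high",
    "transfer_methods": ["remote_url"]}}})
```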

View File

@@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
         :param config: model config args
         """
         more_like_this = False
-        more_like_this_dict = config.get("more_like_this")
+        more_like_this_dict = config.get('more_like_this')
         if more_like_this_dict:
-            if more_like_this_dict.get("enabled"):
+            if more_like_this_dict.get('enabled'):
                 more_like_this = True

         return more_like_this
@@ -22,7 +22,9 @@ class MoreLikeThisConfigManager:
         :param config: app model config args
         """
         if not config.get("more_like_this"):
-            config["more_like_this"] = {"enabled": False}
+            config["more_like_this"] = {
+                "enabled": False
+            }

         if not isinstance(config["more_like_this"], dict):
             raise ValueError("more_like_this must be of dict type")

View File

@@ -1,3 +1,5 @@
 class OpeningStatementConfigManager:
     @classmethod
     def convert(cls, config: dict) -> tuple[str, list]:
@@ -7,10 +9,10 @@ class OpeningStatementConfigManager:
         :param config: model config args
         """
         # opening statement
-        opening_statement = config.get("opening_statement")
+        opening_statement = config.get('opening_statement')

         # suggested questions
-        suggested_questions_list = config.get("suggested_questions")
+        suggested_questions_list = config.get('suggested_questions')

         return opening_statement, suggested_questions_list

View File

@@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
     @classmethod
     def convert(cls, config: dict) -> bool:
         show_retrieve_source = False
-        retriever_resource_dict = config.get("retriever_resource")
+        retriever_resource_dict = config.get('retriever_resource')
         if retriever_resource_dict:
-            if retriever_resource_dict.get("enabled"):
+            if retriever_resource_dict.get('enabled'):
                 show_retrieve_source = True

         return show_retrieve_source
@@ -17,7 +17,9 @@ class RetrievalResourceConfigManager:
         :param config: app model config args
         """
         if not config.get("retriever_resource"):
-            config["retriever_resource"] = {"enabled": False}
+            config["retriever_resource"] = {
+                "enabled": False
+            }

         if not isinstance(config["retriever_resource"], dict):
             raise ValueError("retriever_resource must be of dict type")

View File

@@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
         :param config: model config args
         """
         speech_to_text = False
-        speech_to_text_dict = config.get("speech_to_text")
+        speech_to_text_dict = config.get('speech_to_text')
         if speech_to_text_dict:
-            if speech_to_text_dict.get("enabled"):
+            if speech_to_text_dict.get('enabled'):
                 speech_to_text = True

         return speech_to_text
@@ -22,7 +22,9 @@ class SpeechToTextConfigManager:
         :param config: app model config args
         """
         if not config.get("speech_to_text"):
-            config["speech_to_text"] = {"enabled": False}
+            config["speech_to_text"] = {
+                "enabled": False
+            }

         if not isinstance(config["speech_to_text"], dict):
             raise ValueError("speech_to_text must be of dict type")

View File

@@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         :param config: model config args
         """
         suggested_questions_after_answer = False
-        suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
+        suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
         if suggested_questions_after_answer_dict:
-            if suggested_questions_after_answer_dict.get("enabled"):
+            if suggested_questions_after_answer_dict.get('enabled'):
                 suggested_questions_after_answer = True

         return suggested_questions_after_answer
@@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         :param config: app model config args
         """
         if not config.get("suggested_questions_after_answer"):
-            config["suggested_questions_after_answer"] = {"enabled": False}
+            config["suggested_questions_after_answer"] = {
+                "enabled": False
+            }

         if not isinstance(config["suggested_questions_after_answer"], dict):
             raise ValueError("suggested_questions_after_answer must be of dict type")

-        if (
-            "enabled" not in config["suggested_questions_after_answer"]
-            or not config["suggested_questions_after_answer"]["enabled"]
-        ):
+        if "enabled" not in config["suggested_questions_after_answer"] or not \
+                config["suggested_questions_after_answer"]["enabled"]:
             config["suggested_questions_after_answer"]["enabled"] = False

         if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):

View File

@@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
         :param config: model config args
         """
         text_to_speech = None
-        text_to_speech_dict = config.get("text_to_speech")
+        text_to_speech_dict = config.get('text_to_speech')
         if text_to_speech_dict:
-            if text_to_speech_dict.get("enabled"):
+            if text_to_speech_dict.get('enabled'):
                 text_to_speech = TextToSpeechEntity(
-                    enabled=text_to_speech_dict.get("enabled"),
-                    voice=text_to_speech_dict.get("voice"),
-                    language=text_to_speech_dict.get("language"),
+                    enabled=text_to_speech_dict.get('enabled'),
+                    voice=text_to_speech_dict.get('voice'),
+                    language=text_to_speech_dict.get('language'),
                 )

         return text_to_speech
@@ -29,7 +29,11 @@ class TextToSpeechConfigManager:
         :param config: app model config args
         """
         if not config.get("text_to_speech"):
-            config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
+            config["text_to_speech"] = {
+                "enabled": False,
+                "voice": "",
+                "language": ""
+            }

         if not isinstance(config["text_to_speech"], dict):
             raise ValueError("text_to_speech must be of dict type")

View File

@@ -1,3 +1,4 @@
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
 from core.app.app_config.entities import WorkflowUIBasedAppConfig
@@ -18,13 +19,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
     """
     Advanced Chatbot App Config Entity.
     """
     pass


 class AdvancedChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
+    def get_app_config(cls, app_model: App,
+                       workflow: Workflow) -> AdvancedChatAppConfig:
         features_dict = workflow.features_dict

         app_mode = AppMode.value_of(app_model.mode)
@@ -33,9 +34,13 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
             app_id=app_model.id,
             app_mode=app_mode,
             workflow_id=workflow.id,
-            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
-            variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
-            additional_features=cls.convert_features(features_dict, app_mode),
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
+                config=features_dict
+            ),
+            variables=WorkflowVariablesConfigManager.convert(
+                workflow=workflow
+            ),
+            additional_features=cls.convert_features(features_dict, app_mode)
         )

         return app_config
@@ -53,7 +58,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
         # file upload validation
         config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
-            config=config, is_vision=False
+            config=config,
+            is_vision=False
         )
         related_config_keys.extend(current_related_config_keys)
@@ -63,8 +69,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
         # suggested_questions_after_answer
         config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
-            config
-        )
+            config)
         related_config_keys.extend(current_related_config_keys)

         # speech_to_text
@@ -81,7 +86,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
         # moderation validation
         config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
-            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
+            tenant_id=tenant_id,
+            config=config,
+            only_structure_validate=only_structure_validate
         )
         related_config_keys.extend(current_related_config_keys)
@@ -91,3 +98,4 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
         filtered_config = {key: config.get(key) for key in related_config_keys}

         return filtered_config

View File

@@ -4,10 +4,12 @@ import os
 import threading
 import uuid
 from collections.abc import Generator
-from typing import Any, Literal, Optional, Union, overload
+from typing import Union

 from flask import Flask, current_app
 from pydantic import ValidationError
+from sqlalchemy import select
+from sqlalchemy.orm import Session

 import contexts
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -15,54 +17,36 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
 from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
 from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
 from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
+from core.app.entities.app_invoke_entities import (
+    AdvancedChatAppGenerateEntity,
+    InvokeFrom,
+)
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
-from models.workflow import Workflow
+from models.workflow import ConversationVariable, Workflow

 logger = logging.getLogger(__name__)


 class AdvancedChatAppGenerator(MessageBasedAppGenerator):
-    @overload
     def generate(
-        self,
-        app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        args: dict,
-        invoke_from: InvokeFrom,
-        stream: Literal[True] = True,
-    ) -> Generator[str, None, None]: ...
-
-    @overload
-    def generate(
-        self,
-        app_model: App,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        args: dict,
-        invoke_from: InvokeFrom,
-        stream: Literal[False] = False,
-    ) -> dict: ...
-
-    def generate(
-        self,
-        app_model: App,
+        self, app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
         invoke_from: InvokeFrom,
         stream: bool = True,
-    ) -> dict[str, Any] | Generator[str, Any, None]:
+    ):
         """
         Generate App response.
@@ -73,37 +57,44 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param invoke_from: invoke from source
         :param stream: is stream
         """
-        if not args.get("query"):
-            raise ValueError("query is required")
+        if not args.get('query'):
+            raise ValueError('query is required')

-        query = args["query"]
+        query = args['query']
         if not isinstance(query, str):
-            raise ValueError("query must be a string")
+            raise ValueError('query must be a string')

-        query = query.replace("\x00", "")
-        inputs = args["inputs"]
+        query = query.replace('\x00', '')
+        inputs = args['inputs']

-        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
+        extras = {
+            "auto_generate_conversation_name": args.get('auto_generate_name', False)
+        }

         # get conversation
         conversation = None
-        conversation_id = args.get("conversation_id")
+        conversation_id = args.get('conversation_id')
         if conversation_id:
-            conversation = self._get_conversation_by_user(
-                app_model=app_model, conversation_id=conversation_id, user=user
-            )
+            conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)

         # parse files
-        files = args["files"] if args.get("files") else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = message_file_parser.validate_and_transform_files_arg(
+                files,
+                file_extra_config,
+                user
+            )
         else:
             file_objs = []

         # convert to app config
-        app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
+        app_config = AdvancedChatAppConfigManager.get_app_config(
+            app_model=app_model,
+            workflow=workflow
+        )

         # get tracing instance
         user_id = user.id if isinstance(user, Account) else user.session_id
@@ -125,7 +116,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
             invoke_from=invoke_from,
             extras=extras,
-            trace_manager=trace_manager,
+            trace_manager=trace_manager
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@@ -135,12 +126,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=invoke_from,
             application_generate_entity=application_generate_entity,
             conversation=conversation,
-            stream=stream,
+            stream=stream
         )

-    def single_iteration_generate(
-        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
-    ) -> dict[str, Any] | Generator[str, Any, None]:
+    def single_iteration_generate(self, app_model: App,
+                                  workflow: Workflow,
+                                  node_id: str,
+                                  user: Account,
+                                  args: dict,
+                                  stream: bool = True):
         """
         Generate App response.
@@ -152,29 +146,43 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param stream: is stream
         """
         if not node_id:
-            raise ValueError("node_id is required")
+            raise ValueError('node_id is required')

-        if args.get("inputs") is None:
-            raise ValueError("inputs is required")
+        if args.get('inputs') is None:
+            raise ValueError('inputs is required')
+
+        extras = {
+            "auto_generate_conversation_name": False
+        }
+
+        # get conversation
+        conversation = None
+        conversation_id = args.get('conversation_id')
+        if conversation_id:
+            conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)

         # convert to app config
-        app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
+        app_config = AdvancedChatAppConfigManager.get_app_config(
+            app_model=app_model,
+            workflow=workflow
+        )

         # init application generate entity
         application_generate_entity = AdvancedChatAppGenerateEntity(
             task_id=str(uuid.uuid4()),
             app_config=app_config,
-            conversation_id=None,
+            conversation_id=conversation.id if conversation else None,
             inputs={},
-            query="",
+            query='',
             files=[],
             user_id=user.id,
             stream=stream,
             invoke_from=InvokeFrom.DEBUGGER,
-            extras={"auto_generate_conversation_name": False},
+            extras=extras,
             single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
-                node_id=node_id, inputs=args["inputs"]
-            ),
+                node_id=node_id,
+                inputs=args['inputs']
+            )
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@@ -183,42 +191,32 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,
             application_generate_entity=application_generate_entity,
-            conversation=None,
-            stream=stream,
+            conversation=conversation,
+            stream=stream
         )

-    def _generate(
-        self,
-        *,
-        workflow: Workflow,
-        user: Union[Account, EndUser],
-        invoke_from: InvokeFrom,
-        application_generate_entity: AdvancedChatAppGenerateEntity,
-        conversation: Optional[Conversation] = None,
-        stream: bool = True,
-    ) -> dict[str, Any] | Generator[str, Any, None]:
-        """
-        Generate App response.
-
-        :param workflow: Workflow
-        :param user: account or end user
-        :param invoke_from: invoke from source
-        :param application_generate_entity: application generate entity
-        :param conversation: conversation
-        :param stream: is stream
-        """
+    def _generate(self, *,
+                  workflow: Workflow,
+                  user: Union[Account, EndUser],
+                  invoke_from: InvokeFrom,
+                  application_generate_entity: AdvancedChatAppGenerateEntity,
+                  conversation: Conversation | None = None,
+                  stream: bool = True):
         is_first_conversation = False
         if not conversation:
             is_first_conversation = True

         # init generate records
-        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
+        (
+            conversation,
+            message
+        ) = self._init_generate_records(application_generate_entity, conversation)

         if is_first_conversation:
             # update conversation features
             conversation.override_model_configs = workflow.features
             db.session.commit()
-            db.session.refresh(conversation)
+            # db.session.refresh(conversation)

         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
@@ -227,21 +225,73 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=application_generate_entity.invoke_from,
             conversation_id=conversation.id,
             app_mode=conversation.mode,
-            message_id=message.id,
+            message_id=message.id
         )

-        # new thread
-        worker_thread = threading.Thread(
-            target=self._generate_worker,
-            kwargs={
-                "flask_app": current_app._get_current_object(),  # type: ignore
-                "application_generate_entity": application_generate_entity,
-                "queue_manager": queue_manager,
-                "conversation_id": conversation.id,
-                "message_id": message.id,
-                "context": contextvars.copy_context(),
-            },
-        )
+        # Init conversation variables
+        stmt = select(ConversationVariable).where(
+            ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
+        )
+        with Session(db.engine) as session:
+            conversation_variables = session.scalars(stmt).all()
+            if not conversation_variables:
+                # Create conversation variables if they don't exist.
+                conversation_variables = [
+                    ConversationVariable.from_variable(
+                        app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
+                    )
+                    for variable in workflow.conversation_variables
+                ]
+                session.add_all(conversation_variables)
+            # Convert database entities to variables.
+            conversation_variables = [item.to_variable() for item in conversation_variables]
+            session.commit()
+
+        # Increment dialogue count.
+        conversation.dialogue_count += 1
+
+        conversation_id = conversation.id
+        conversation_dialogue_count = conversation.dialogue_count
+        db.session.commit()
+        db.session.refresh(conversation)
+
+        inputs = application_generate_entity.inputs
+        query = application_generate_entity.query
+        files = application_generate_entity.files
+
+        user_id = None
+        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
+            if end_user:
+                user_id = end_user.session_id
+        else:
+            user_id = application_generate_entity.user_id
+
+        # Create a variable pool.
+        system_inputs = {
+            SystemVariableKey.QUERY: query,
+            SystemVariableKey.FILES: files,
+            SystemVariableKey.CONVERSATION_ID: conversation_id,
+            SystemVariableKey.USER_ID: user_id,
+            SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
+        }
+        variable_pool = VariablePool(
+            system_variables=system_inputs,
+            user_inputs=inputs,
+            environment_variables=workflow.environment_variables,
+            conversation_variables=conversation_variables,
+        )
+        contexts.workflow_variable_pool.set(variable_pool)
+
+        # new thread
+        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
+            'flask_app': current_app._get_current_object(),
+            'application_generate_entity': application_generate_entity,
+            'queue_manager': queue_manager,
+            'message_id': message.id,
+            'context': contextvars.copy_context(),
+        })

         worker_thread.start()
@@ -256,17 +306,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
         )

-        return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
+        return AdvancedChatAppGenerateResponseConverter.convert(
+            response=response,
+            invoke_from=invoke_from
+        )

-    def _generate_worker(
-        self,
-        flask_app: Flask,
-        application_generate_entity: AdvancedChatAppGenerateEntity,
-        queue_manager: AppQueueManager,
-        conversation_id: str,
-        message_id: str,
-        context: contextvars.Context,
-    ) -> None:
+    def _generate_worker(self, flask_app: Flask,
+                         application_generate_entity: AdvancedChatAppGenerateEntity,
+                         queue_manager: AppQueueManager,
+                         message_id: str,
+                         context: contextvars.Context) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -280,30 +329,40 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             var.set(val)
         with flask_app.app_context():
             try:
-                # get conversation and message
-                conversation = self._get_conversation(conversation_id)
-                message = self._get_message(message_id)
-
-                # chatbot app
-                runner = AdvancedChatAppRunner(
-                    application_generate_entity=application_generate_entity,
-                    queue_manager=queue_manager,
-                    conversation=conversation,
-                    message=message,
-                )
-
-                runner.run()
-            except GenerateTaskStoppedError:
+                runner = AdvancedChatAppRunner()
+                if application_generate_entity.single_iteration_run:
+                    single_iteration_run = application_generate_entity.single_iteration_run
+                    runner.single_iteration_run(
+                        app_id=application_generate_entity.app_config.app_id,
+                        workflow_id=application_generate_entity.app_config.workflow_id,
+                        queue_manager=queue_manager,
+                        inputs=single_iteration_run.inputs,
+                        node_id=single_iteration_run.node_id,
+                        user_id=application_generate_entity.user_id
+                    )
+                else:
+                    # get message
+                    message = self._get_message(message_id)
+
+                    # chatbot app
+                    runner = AdvancedChatAppRunner()
+                    runner.run(
+                        application_generate_entity=application_generate_entity,
+                        queue_manager=queue_manager,
+                        message=message
+                    )
+            except GenerateTaskStoppedException:
                 pass
             except InvokeAuthorizationError:
                 queue_manager.publish_error(
-                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
+                    InvokeAuthorizationError('Incorrect API key provided'),
+                    PublishFrom.APPLICATION_MANAGER
                 )
             except ValidationError as e:
                 logger.exception("Validation Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                if os.environ.get("DEBUG", "false").lower() == "true":
+                if os.environ.get("DEBUG", "false").lower() == 'true':
                     logger.exception("Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
@@ -349,7 +408,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             return generate_task_pipeline.process()
         except ValueError as e:
             if e.args[0] == "I/O operation on closed file.":  # ignore this error
-                raise GenerateTaskStoppedError()
+                raise GenerateTaskStoppedException()
             else:
                 logger.exception(e)
                 raise e
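The removed side of this file types `generate` with `typing.overload` so that type checkers resolve `stream=True` calls to a generator return type and `stream=False` calls to a dict, while a single untyped implementation serves both. The pattern in isolation, with illustrative names (`AppGenerator` and the literal return values are mine):

```python
from collections.abc import Generator
from typing import Literal, overload


class AppGenerator:
    @overload
    def generate(self, args: dict, stream: Literal[True] = True) -> Generator[str, None, None]: ...

    @overload
    def generate(self, args: dict, stream: Literal[False] = False) -> dict: ...

    def generate(self, args: dict, stream: bool = True):
        # One runtime implementation; the @overload stubs above exist only
        # for type checkers and are replaced by this def at runtime.
        if stream:
            return (chunk for chunk in ("a", "b"))
        return {"answer": "ab"}
```

Callers get precise types without a cast: `AppGenerator().generate({}, stream=False)` is a `dict` to mypy, while the `stream=True` form is a `Generator[str, None, None]`.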

View File

@@ -21,11 +21,14 @@ class AudioTrunk:
         self.status = status


-def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
+def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
     if not text_content or text_content.isspace():
         return
     return model_instance.invoke_tts(
-        content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
+        content_text=text_content.strip(),
+        user="responding_tts",
+        tenant_id=tenant_id,
+        voice=voice
     )
@@ -41,26 +44,28 @@ def _process_future(future_queue, audio_queue):
         except Exception as e:
             logging.getLogger(__name__).warning(e)
             break
-    audio_queue.put(AudioTrunk("finish", b""))
+    audio_queue.put(AudioTrunk("finish", b''))


 class AppGeneratorTTSPublisher:
     def __init__(self, tenant_id: str, voice: str):
         self.logger = logging.getLogger(__name__)
         self.tenant_id = tenant_id
-        self.msg_text = ""
+        self.msg_text = ''
         self._audio_queue = queue.Queue()
         self._msg_queue = queue.Queue()
-        self.match = re.compile(r"[。.!?]")
+        self.match = re.compile(r'[。.!?]')
         self.model_manager = ModelManager()
         self.model_instance = self.model_manager.get_default_model_instance(
-            tenant_id=self.tenant_id, model_type=ModelType.TTS
+            tenant_id=self.tenant_id,
+            model_type=ModelType.TTS
         )
         self.voices = self.model_instance.get_tts_voices()
-        values = [voice.get("value") for voice in self.voices]
+        values = [voice.get('value') for voice in self.voices]
         self.voice = voice
         if not voice or voice not in values:
-            self.voice = self.voices[0].get("value")
+            self.voice = self.voices[0].get('value')
         self.MAX_SENTENCE = 2
         self._last_audio_event = None
         self._runtime_thread = threading.Thread(target=self._runtime).start()
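One detail worth noting in both columns of the hunk above: `self._runtime_thread = threading.Thread(target=self._runtime).start()` stores the return value of `start()`, which is always `None`, so the attribute never actually holds the `Thread` object. A small demonstration of the difference:

```python
import threading


def demo() -> tuple[object, threading.Thread]:
    results: list[int] = []

    # Chaining .start() onto the constructor stores None, because
    # Thread.start() returns None; the Thread object itself is lost.
    chained = threading.Thread(target=results.append, args=(1,)).start()

    # Keeping a usable handle requires two statements.
    worker = threading.Thread(target=results.append, args=(2,))
    worker.start()
    worker.join()
    return chained, worker
```

If the thread handle is never joined or inspected, the chained form happens to work, which is presumably why this survives in both versions of the file.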
@@ -80,9 +85,8 @@ class AppGeneratorTTSPublisher:
                 message = self._msg_queue.get()
                 if message is None:
                     if self.msg_text and len(self.msg_text.strip()) > 0:
-                        futures_result = self.executor.submit(
-                            _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
-                        )
+                        futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
+                                                              self.model_instance, self.tenant_id, self.voice)
                         future_queue.put(futures_result)
                     break
                 elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
@@ -90,27 +94,28 @@ class AppGeneratorTTSPublisher:
                 elif isinstance(message.event, QueueTextChunkEvent):
                     self.msg_text += message.event.text
                 elif isinstance(message.event, QueueNodeSucceededEvent):
-                    self.msg_text += message.event.outputs.get("output", "")
+                    self.msg_text += message.event.outputs.get('output', '')
                 self.last_message = message
                 sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
                 if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
                     self.MAX_SENTENCE += 1
-                    text_content = "".join(sentence_arr)
-                    futures_result = self.executor.submit(
-                        _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
-                    )
+                    text_content = ''.join(sentence_arr)
+                    futures_result = self.executor.submit(_invoiceTTS, text_content,
+                                                          self.model_instance,
+                                                          self.tenant_id,
+                                                          self.voice)
                     future_queue.put(futures_result)
                     if text_tmp:
                         self.msg_text = text_tmp
                     else:
-                        self.msg_text = ""
+                        self.msg_text = ''
             except Exception as e:
                 self.logger.warning(e)
                 break
         future_queue.put(None)

-    def check_and_get_audio(self) -> AudioTrunk | None:
+    def checkAndGetAudio(self) -> AudioTrunk | None:
         try:
             if self._last_audio_event and self._last_audio_event.status == "finish":
                 if self.executor:


@@ -1,197 +1,145 @@
 import logging
 import os
+import time
 from collections.abc import Mapping
-from typing import Any, cast
+from typing import Any, Optional, cast

-from sqlalchemy import select
-from sqlalchemy.orm import Session

 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
-from core.app.apps.base_app_queue_manager import AppQueueManager
-from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
+from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.base_app_runner import AppRunner
 from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
     InvokeFrom,
 )
-from core.app.entities.queue_entities import (
-    QueueAnnotationReplyEvent,
-    QueueStopEvent,
-    QueueTextChunkEvent,
-)
-from core.moderation.base import ModerationError
+from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
+from core.moderation.base import ModerationException
 from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.entities.node_entities import UserFrom
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
-from core.workflow.workflow_entry import WorkflowEntry
+from core.workflow.nodes.base_node import UserFrom
+from core.workflow.workflow_engine_manager import WorkflowEngineManager
 from extensions.ext_database import db
-from models.model import App, Conversation, EndUser, Message
-from models.workflow import ConversationVariable, WorkflowType
+from models import App, Message, Workflow

 logger = logging.getLogger(__name__)

-class AdvancedChatAppRunner(WorkflowBasedAppRunner):
+class AdvancedChatAppRunner(AppRunner):
     """
     AdvancedChat Application Runner
     """

-    def __init__(
+    def run(
         self,
         application_generate_entity: AdvancedChatAppGenerateEntity,
         queue_manager: AppQueueManager,
-        conversation: Conversation,
         message: Message,
     ) -> None:
         """
+        Run application
         :param application_generate_entity: application generate entity
         :param queue_manager: application queue manager
         :param conversation: conversation
         :param message: message
-        """
-        super().__init__(queue_manager)
-        self.application_generate_entity = application_generate_entity
-        self.conversation = conversation
-        self.message = message
-
-    def run(self) -> None:
-        """
-        Run application
         :return:
         """
-        app_config = self.application_generate_entity.app_config
+        app_config = application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)

         app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
         if not app_record:
-            raise ValueError("App not found")
+            raise ValueError('App not found')

         workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
         if not workflow:
-            raise ValueError("Workflow not initialized")
+            raise ValueError('Workflow not initialized')

-        user_id = None
-        if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
-            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
-            if end_user:
-                user_id = end_user.session_id
-        else:
-            user_id = self.application_generate_entity.user_id
-
-        workflow_callbacks: list[WorkflowCallback] = []
-        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
-            workflow_callbacks.append(WorkflowLoggingCallback())
+        inputs = application_generate_entity.inputs
+        query = application_generate_entity.query
+
+        # moderation
+        if self.handle_input_moderation(
+            queue_manager=queue_manager,
+            app_record=app_record,
+            app_generate_entity=application_generate_entity,
+            inputs=inputs,
+            query=query,
+            message_id=message.id,
+        ):
+            return
-        if self.application_generate_entity.single_iteration_run:
-            # if only single iteration run is requested
-            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
-                workflow=workflow,
-                node_id=self.application_generate_entity.single_iteration_run.node_id,
-                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
-            )
-        else:
-            inputs = self.application_generate_entity.inputs
-            query = self.application_generate_entity.query
-            files = self.application_generate_entity.files
-
-            # moderation
-            if self.handle_input_moderation(
-                app_record=app_record,
-                app_generate_entity=self.application_generate_entity,
-                inputs=inputs,
-                query=query,
-                message_id=self.message.id,
-            ):
-                return
-
-            # annotation reply
-            if self.handle_annotation_reply(
-                app_record=app_record,
-                message=self.message,
-                query=query,
-                app_generate_entity=self.application_generate_entity,
-            ):
-                return
+        # annotation reply
+        if self.handle_annotation_reply(
+            app_record=app_record,
+            message=message,
+            query=query,
+            queue_manager=queue_manager,
+            app_generate_entity=application_generate_entity,
+        ):
+            return
-        # Init conversation variables
-        stmt = select(ConversationVariable).where(
-            ConversationVariable.app_id == self.conversation.app_id,
-            ConversationVariable.conversation_id == self.conversation.id,
-        )
-        with Session(db.engine) as session:
-            conversation_variables = session.scalars(stmt).all()
-            if not conversation_variables:
-                # Create conversation variables if they don't exist.
-                conversation_variables = [
-                    ConversationVariable.from_variable(
-                        app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
-                    )
-                    for variable in workflow.conversation_variables
-                ]
-                session.add_all(conversation_variables)
-            # Convert database entities to variables.
-            conversation_variables = [item.to_variable() for item in conversation_variables]
-            session.commit()
-
-        # Increment dialogue count.
-        self.conversation.dialogue_count += 1
-
-        conversation_dialogue_count = self.conversation.dialogue_count
-        db.session.commit()
-
-        # Create a variable pool.
-        system_inputs = {
-            SystemVariableKey.QUERY: query,
-            SystemVariableKey.FILES: files,
-            SystemVariableKey.CONVERSATION_ID: self.conversation.id,
-            SystemVariableKey.USER_ID: user_id,
-            SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
-        }
-
-        # init variable pool
-        variable_pool = VariablePool(
-            system_variables=system_inputs,
-            user_inputs=inputs,
-            environment_variables=workflow.environment_variables,
-            conversation_variables=conversation_variables,
-        )
-
-        # init graph
-        graph = self._init_graph(graph_config=workflow.graph_dict)

         db.session.close()

+        workflow_callbacks: list[WorkflowCallback] = [
+            WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
+        ]
+        if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
+            workflow_callbacks.append(WorkflowLoggingCallback())
         # RUN WORKFLOW
-        workflow_entry = WorkflowEntry(
-            tenant_id=workflow.tenant_id,
-            app_id=workflow.app_id,
-            workflow_id=workflow.id,
-            workflow_type=WorkflowType.value_of(workflow.type),
-            graph=graph,
-            graph_config=workflow.graph_dict,
-            user_id=self.application_generate_entity.user_id,
-            user_from=(
-                UserFrom.ACCOUNT
-                if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
-                else UserFrom.END_USER
-            ),
-            invoke_from=self.application_generate_entity.invoke_from,
-            call_depth=self.application_generate_entity.call_depth,
-            variable_pool=variable_pool,
-        )
-
-        generator = workflow_entry.run(
+        workflow_engine_manager = WorkflowEngineManager()
+        workflow_engine_manager.run_workflow(
+            workflow=workflow,
+            user_id=application_generate_entity.user_id,
+            user_from=UserFrom.ACCOUNT
+            if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+            else UserFrom.END_USER,
+            invoke_from=application_generate_entity.invoke_from,
             callbacks=workflow_callbacks,
+            call_depth=application_generate_entity.call_depth,
         )
-
-        for event in generator:
-            self._handle_event(workflow_entry, event)
+
+    def single_iteration_run(
+        self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
+    ) -> None:
+        """
+        Single iteration run
+        """
+        app_record = db.session.query(App).filter(App.id == app_id).first()
+        if not app_record:
+            raise ValueError('App not found')
+
+        workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
+        if not workflow:
+            raise ValueError('Workflow not initialized')
+
+        workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
+
+        workflow_engine_manager = WorkflowEngineManager()
+        workflow_engine_manager.single_step_run_iteration_workflow_node(
+            workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
+        )
+
+    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
+        """
+        Get workflow
+        """
+        # fetch workflow by workflow_id
+        workflow = (
+            db.session.query(Workflow)
+            .filter(
+                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
+            )
+            .first()
+        )
+
+        # return workflow
+        return workflow

     def handle_input_moderation(
         self,
+        queue_manager: AppQueueManager,
         app_record: App,
         app_generate_entity: AdvancedChatAppGenerateEntity,
         inputs: Mapping[str, Any],
@@ -200,6 +148,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     ) -> bool:
         """
         Handle input moderation
+        :param queue_manager: application queue manager
         :param app_record: app record
         :param app_generate_entity: application generate entity
         :param inputs: inputs
@@ -217,20 +166,31 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
                 query=query,
                 message_id=message_id,
             )
-        except ModerationError as e:
-            self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
+        except ModerationException as e:
+            self._stream_output(
+                queue_manager=queue_manager,
+                text=str(e),
+                stream=app_generate_entity.stream,
+                stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
+            )
             return True

         return False

     def handle_annotation_reply(
-        self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
+        self,
+        app_record: App,
+        message: Message,
+        query: str,
+        queue_manager: AppQueueManager,
+        app_generate_entity: AdvancedChatAppGenerateEntity,
     ) -> bool:
         """
         Handle annotation reply
         :param app_record: app record
         :param message: message
         :param query: query
+        :param queue_manager: application queue manager
         :param app_generate_entity: application generate entity
         """
         # annotation reply
@@ -243,21 +203,37 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         )

         if annotation_reply:
-            self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))
+            queue_manager.publish(
+                QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
+            )

-            self._complete_with_stream_output(
-                text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
+            self._stream_output(
+                queue_manager=queue_manager,
+                text=annotation_reply.content,
+                stream=app_generate_entity.stream,
+                stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
             )
             return True

         return False

-    def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
+    def _stream_output(
+        self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
+    ) -> None:
         """
         Direct output
+        :param queue_manager: application queue manager
         :param text: text
+        :param stream: stream
         :return:
         """
-        self._publish_event(QueueTextChunkEvent(text=text))
+        if stream:
+            index = 0
+            for token in text:
+                queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
+                index += 1
+                time.sleep(0.01)
+        else:
+            queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)

-        self._publish_event(QueueStopEvent(stopped_by=stopped_by))
+        queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
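The older `_stream_output` above publishes one `QueueTextChunkEvent` per character with a 10 ms pause to simulate typing, while the blocking path publishes the whole text at once. The control flow can be sketched independently of the queue machinery (the `publish` callback stands in for `queue_manager.publish`; names are illustrative):

```python
import time
from collections.abc import Callable


def stream_output(publish: Callable[[str], None], text: str, stream: bool, delay: float = 0.01) -> int:
    """Publish text char by char when streaming, else as one chunk.

    Returns the number of published chunks.
    """
    if not stream:
        publish(text)
        return 1
    count = 0
    for token in text:
        publish(token)       # one text-chunk event per character in the original
        count += 1
        time.sleep(delay)    # 10 ms pacing between characters, as above
    return count
```

The refactored side drops the per-character pacing entirely and emits the text as a single `QueueTextChunkEvent`, leaving the typing effect to the client.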


@@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
         response = {
-            "event": "message",
-            "task_id": blocking_response.task_id,
-            "id": blocking_response.data.id,
-            "message_id": blocking_response.data.message_id,
-            "conversation_id": blocking_response.data.conversation_id,
-            "mode": blocking_response.data.mode,
-            "answer": blocking_response.data.answer,
-            "metadata": blocking_response.data.metadata,
-            "created_at": blocking_response.data.created_at,
+            'event': 'message',
+            'task_id': blocking_response.task_id,
+            'id': blocking_response.data.id,
+            'message_id': blocking_response.data.message_id,
+            'conversation_id': blocking_response.data.conversation_id,
+            'mode': blocking_response.data.mode,
+            'answer': blocking_response.data.answer,
+            'metadata': blocking_response.data.metadata,
+            'created_at': blocking_response.data.created_at
         }

         return response
@@ -50,15 +50,13 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         response = cls.convert_blocking_full_response(blocking_response)

-        metadata = response.get("metadata", {})
-        response["metadata"] = cls._get_simple_metadata(metadata)
+        metadata = response.get('metadata', {})
+        response['metadata'] = cls._get_simple_metadata(metadata)

         return response

     @classmethod
-    def convert_stream_full_response(
-        cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, Any, None]:
+    def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -69,14 +67,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response

             if isinstance(sub_stream_response, PingStreamResponse):
-                yield "ping"
+                yield 'ping'
                 continue

             response_chunk = {
-                "event": sub_stream_response.event.value,
-                "conversation_id": chunk.conversation_id,
-                "message_id": chunk.message_id,
-                "created_at": chunk.created_at,
+                'event': sub_stream_response.event.value,
+                'conversation_id': chunk.conversation_id,
+                'message_id': chunk.message_id,
+                'created_at': chunk.created_at
             }

             if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -87,9 +85,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)

     @classmethod
-    def convert_stream_simple_response(
-        cls, stream_response: Generator[AppStreamResponse, None, None]
-    ) -> Generator[str, Any, None]:
+    def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -100,20 +96,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response

             if isinstance(sub_stream_response, PingStreamResponse):
-                yield "ping"
+                yield 'ping'
                 continue

             response_chunk = {
-                "event": sub_stream_response.event.value,
-                "conversation_id": chunk.conversation_id,
-                "message_id": chunk.message_id,
-                "created_at": chunk.created_at,
+                'event': sub_stream_response.event.value,
+                'conversation_id': chunk.conversation_id,
+                'message_id': chunk.message_id,
+                'created_at': chunk.created_at
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
                 sub_stream_response_dict = sub_stream_response.to_dict()
-                metadata = sub_stream_response_dict.get("metadata", {})
-                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
+                metadata = sub_stream_response_dict.get('metadata', {})
+                sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)


@@ -2,8 +2,9 @@ import json
 import logging
 import time
 from collections.abc import Generator
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, cast

+import contexts
 from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
 from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -21,9 +22,6 @@ from core.app.entities.queue_entities import (
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
-    QueueParallelBranchRunFailedEvent,
-    QueueParallelBranchRunStartedEvent,
-    QueueParallelBranchRunSucceededEvent,
     QueuePingEvent,
     QueueRetrieverResourcesEvent,
     QueueStopEvent,
@@ -33,28 +31,34 @@ from core.app.entities.queue_entities import (
     QueueWorkflowSucceededEvent,
 )
 from core.app.entities.task_entities import (
+    AdvancedChatTaskState,
     ChatbotAppBlockingResponse,
     ChatbotAppStreamResponse,
+    ChatflowStreamGenerateRoute,
     ErrorStreamResponse,
     MessageAudioEndStreamResponse,
     MessageAudioStreamResponse,
     MessageEndStreamResponse,
     StreamResponse,
-    WorkflowTaskState,
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
 from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
+from core.file.file_obj import FileVar
+from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.ops_trace_manager import TraceQueueManager
+from core.workflow.entities.node_entities import NodeType
 from core.workflow.enums import SystemVariableKey
-from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from models.account import Account
 from models.model import Conversation, EndUser, Message
 from models.workflow import (
     Workflow,
+    WorkflowNodeExecution,
     WorkflowRunStatus,
 )
@@ -65,22 +69,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
     """
    AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
-    _task_state: WorkflowTaskState
+    _task_state: AdvancedChatTaskState
     _application_generate_entity: AdvancedChatAppGenerateEntity
     _workflow: Workflow
     _user: Union[Account, EndUser]
-    # Deprecated
     _workflow_system_variables: dict[SystemVariableKey, Any]
+    _iteration_nested_relations: dict[str, list[str]]

     def __init__(
-        self,
-        application_generate_entity: AdvancedChatAppGenerateEntity,
+        self, application_generate_entity: AdvancedChatAppGenerateEntity,
         workflow: Workflow,
         queue_manager: AppQueueManager,
         conversation: Conversation,
         message: Message,
         user: Union[Account, EndUser],
         stream: bool,
     ) -> None:
         """
         Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -102,6 +106,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
         self._workflow = workflow
         self._conversation = conversation
         self._message = message
-        # Deprecated
         self._workflow_system_variables = {
             SystemVariableKey.QUERY: message.query,
             SystemVariableKey.FILES: application_generate_entity.files,
@@ -109,8 +114,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             SystemVariableKey.USER_ID: user_id,
         }

-        self._task_state = WorkflowTaskState()
+        self._task_state = AdvancedChatTaskState(
+            usage=LLMUsage.empty_usage()
+        )
+
+        self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
+        self._stream_generate_routes = self._get_stream_generate_routes()

         self._conversation_name_generate_thread = None
     def process(self):
@@ -124,11 +133,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

         # start generate conversation name thread
         self._conversation_name_generate_thread = self._generate_conversation_name(
-            self._conversation, self._application_generate_entity.query
+            self._conversation,
+            self._application_generate_entity.query
         )

-        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
+        generator = self._wrapper_process_stream_response(
+            trace_manager=self._application_generate_entity.trace_manager
+        )
         if self._stream:
             return self._to_stream_response(generator)
         else:
@@ -145,7 +156,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             elif isinstance(stream_response, MessageEndStreamResponse):
                 extras = {}
                 if stream_response.metadata:
-                    extras["metadata"] = stream_response.metadata
+                    extras['metadata'] = stream_response.metadata

                 return ChatbotAppBlockingResponse(
                     task_id=stream_response.task_id,
@@ -156,17 +167,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         message_id=self._message.id,
                         answer=self._task_state.answer,
                         created_at=int(self._message.created_at.timestamp()),
-                        **extras,
-                    ),
+                        **extras
+                    )
                 )
             else:
                 continue

-        raise Exception("Queue listening stopped unexpectedly.")
+        raise Exception('Queue listening stopped unexpectedly.')
-    def _to_stream_response(
-        self, generator: Generator[StreamResponse, None, None]
-    ) -> Generator[ChatbotAppStreamResponse, Any, None]:
+    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
         """
         To stream response.
         :return:
         """
@@ -176,35 +185,31 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 conversation_id=self._conversation.id,
                 message_id=self._message.id,
                 created_at=int(self._message.created_at.timestamp()),
-                stream_response=stream_response,
+                stream_response=stream_response
             )

-    def _listen_audio_msg(self, publisher, task_id: str):
+    def _listenAudioMsg(self, publisher, task_id: str):
         if not publisher:
             return None
-        audio_msg: AudioTrunk = publisher.check_and_get_audio()
+        audio_msg: AudioTrunk = publisher.checkAndGetAudio()
         if audio_msg and audio_msg.status != "finish":
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None
    def _wrapper_process_stream_response(
        self, trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        tts_publisher = None
        task_id = self._application_generate_entity.task_id
        tenant_id = self._application_generate_entity.app_config.tenant_id
        features_dict = self._workflow.features_dict

        if (
            features_dict.get("text_to_speech")
            and features_dict["text_to_speech"].get("enabled")
            and features_dict["text_to_speech"].get("autoPlay") == "enabled"
        ):
            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))

        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
            while True:
                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
                if audio_response:
                    yield audio_response
                else:

@@ -215,9 +220,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

            # timeout
            while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
                try:
                    if not tts_publisher:
                        break
                    audio_trunk = tts_publisher.check_and_get_audio()
                    if audio_trunk is None:
                        # release cpu
                        # sleep 20 ms ( 40ms => 1280 byte audio file, 20ms => 640 byte audio file)

@@ -231,38 +236,38 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

                except Exception as e:
                    logger.error(e)
                    break
        yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
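The audio-drain pattern above (poll `check_and_get_audio()` until a finish marker arrives or `TTS_AUTO_PLAY_TIMEOUT` elapses, sleeping briefly when the buffer is empty) can be sketched in isolation. This is a minimal stand-in, not the project's classes: `FakePublisher`, `drain_audio`, and the string chunks are all illustrative.

```python
import time
from collections import deque
from typing import Optional

TTS_AUTO_PLAY_TIMEOUT = 5.0  # seconds, mirrors the constant used in the loop above


class FakePublisher:
    """Illustrative stand-in for the TTS publisher's buffered-audio queue."""

    def __init__(self, chunks):
        self._queue = deque(chunks)

    def check_and_get_audio(self):
        # Returns None when nothing is buffered yet, like the real poll call.
        return self._queue.popleft() if self._queue else None


def drain_audio(publisher: Optional[FakePublisher]) -> list:
    """Collect audio chunks the way the timeout loop above does."""
    received = []
    start = time.time()
    while (time.time() - start) < TTS_AUTO_PLAY_TIMEOUT:
        if not publisher:
            break
        chunk = publisher.check_and_get_audio()
        if chunk is None:
            time.sleep(0.02)  # release cpu between polls
            continue
        if chunk == "finish":  # sentinel marking end of the audio stream
            break
        received.append(chunk)
    return received


print(drain_audio(FakePublisher(["a", "b", "finish"])))
```

The sentinel-plus-timeout combination is what lets the wrapper terminate promptly when TTS finishes while still bounding the wait if the publisher stalls.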
    def _process_stream_response(
        self,
        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.
        :return:
        """
        # init fake graph runtime state
        graph_runtime_state = None
        workflow_run = None

        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueuePingEvent):
                yield self._ping_stream_response()
            elif isinstance(event, QueueErrorEvent):
                err = self._handle_error(event, self._message)
                yield self._error_to_stream_response(err)
                break
            elif isinstance(event, QueueWorkflowStartedEvent):
                # override graph runtime state
                graph_runtime_state = event.graph_runtime_state

                # init workflow run
                workflow_run = self._handle_workflow_run_start()

                self._refetch_message()
                self._message.workflow_run_id = workflow_run.id

                db.session.commit()

@@ -270,231 +275,137 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

                db.session.close()

                yield self._workflow_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )
            elif isinstance(event, QueueNodeStartedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

                response = self._workflow_node_start_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution,
                )

                if response:
                    yield response
            elif isinstance(event, QueueNodeSucceededEvent):
                workflow_node_execution = self._handle_workflow_node_execution_success(event)

                response = self._workflow_node_finish_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution,
                )

                if response:
                    yield response
            elif isinstance(event, QueueNodeFailedEvent):
                workflow_node_execution = self._handle_workflow_node_execution_failed(event)

                response = self._workflow_node_finish_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution,
                )

                if response:
                    yield response
            elif isinstance(event, QueueParallelBranchRunStartedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_parallel_branch_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_parallel_branch_finished_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationStartEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationNextEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_next_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationCompletedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_completed_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueWorkflowSucceededEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                if not graph_runtime_state:
                    raise Exception("Graph runtime state not initialized.")

                workflow_run = self._handle_workflow_run_success(
                    workflow_run=workflow_run,
                    start_at=graph_runtime_state.start_at,
                    total_tokens=graph_runtime_state.total_tokens,
                    total_steps=graph_runtime_state.node_run_steps,
                    outputs=json.dumps(event.outputs) if event.outputs else None,
                    conversation_id=self._conversation.id,
                    trace_manager=trace_manager,
                )

                yield self._workflow_finish_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )

                self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
            elif isinstance(event, QueueWorkflowFailedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                if not graph_runtime_state:
                    raise Exception("Graph runtime state not initialized.")

                workflow_run = self._handle_workflow_run_failed(
                    workflow_run=workflow_run,
                    start_at=graph_runtime_state.start_at,
                    total_tokens=graph_runtime_state.total_tokens,
                    total_steps=graph_runtime_state.node_run_steps,
                    status=WorkflowRunStatus.FAILED,
                    error=event.error,
                    conversation_id=self._conversation.id,
                    trace_manager=trace_manager,
                )

                yield self._workflow_finish_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )

                err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
                yield self._error_to_stream_response(self._handle_error(err_event, self._message))
                break
            elif isinstance(event, QueueStopEvent):
                if workflow_run and graph_runtime_state:
                    workflow_run = self._handle_workflow_run_failed(
                        workflow_run=workflow_run,
                        start_at=graph_runtime_state.start_at,
                        total_tokens=graph_runtime_state.total_tokens,
                        total_steps=graph_runtime_state.node_run_steps,
                        status=WorkflowRunStatus.STOPPED,
                        error=event.get_stop_reason(),
                        conversation_id=self._conversation.id,
                        trace_manager=trace_manager,
                    )

                    yield self._workflow_finish_to_stream_response(
                        task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                    )

                # Save message
                self._save_message(graph_runtime_state=graph_runtime_state)

                yield self._message_end_to_stream_response()
                break
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._handle_retriever_resources(event)

                self._refetch_message()

                self._message.message_metadata = (
                    json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                )

                db.session.commit()
                db.session.refresh(self._message)
                db.session.close()
            elif isinstance(event, QueueAnnotationReplyEvent):
                self._handle_annotation_reply(event)

                self._refetch_message()

                self._message.message_metadata = (
                    json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                )

                db.session.commit()
                db.session.refresh(self._message)
                db.session.close()
            elif isinstance(event, QueueTextChunkEvent):
                delta_text = event.text
                if delta_text is None:
                    continue

                # handle output moderation chunk
                should_direct_answer = self._handle_output_moderation_chunk(delta_text)
                if should_direct_answer:
                    continue

                # only publish tts message at text chunk streaming
                if tts_publisher:
                    tts_publisher.publish(message=queue_message)

                self._task_state.answer += delta_text
                yield self._message_to_stream_response(
                    answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector
                )
            elif isinstance(event, QueueMessageReplaceEvent):
                # published by moderation
                yield self._message_replace_to_stream_response(answer=event.text)
            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                if not graph_runtime_state:
                    raise Exception("Graph runtime state not initialized.")

                output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
                if output_moderation_answer:
                    self._task_state.answer = output_moderation_answer
                    yield self._message_replace_to_stream_response(answer=output_moderation_answer)

                # Save message
                self._save_message(graph_runtime_state=graph_runtime_state)

                yield self._message_end_to_stream_response()
            else:
                continue

        # publish None when task finished
        if tts_publisher:
            tts_publisher.publish(None)

        if self._conversation_name_generate_thread:
            self._conversation_name_generate_thread.join()

    def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
        """
        Save message.
        :return:
        """
        self._refetch_message()

        self._message.answer = self._task_state.answer
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.message_metadata = (
            json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
        )

        if graph_runtime_state and graph_runtime_state.llm_usage:
            usage = graph_runtime_state.llm_usage
            self._message.message_tokens = usage.prompt_tokens
            self._message.message_unit_price = usage.prompt_unit_price
            self._message.message_price_unit = usage.prompt_price_unit

@@ -511,7 +422,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras,
        )

    def _message_end_to_stream_response(self) -> MessageEndStreamResponse:

@@ -521,15 +432,331 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

        """
        extras = {}
        if self._task_state.metadata:
            extras["metadata"] = self._task_state.metadata.copy()

            if "annotation_reply" in extras["metadata"]:
                del extras["metadata"]["annotation_reply"]

        return MessageEndStreamResponse(
            task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
        )
    def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
        """
        Get stream generate routes.
        :return:
        """
        # find all answer nodes
        graph = self._workflow.graph_dict
        answer_node_configs = [
            node for node in graph['nodes']
            if node.get('data', {}).get('type') == NodeType.ANSWER.value
        ]

        # parse stream output node value selectors of answer nodes
        stream_generate_routes = {}
        for node_config in answer_node_configs:
            # get generate route for stream output
            answer_node_id = node_config['id']
            generate_route = AnswerNode.extract_generate_route_selectors(node_config)
            start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
            if not start_node_ids:
                continue

            for start_node_id in start_node_ids:
                stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
                    answer_node_id=answer_node_id,
                    generate_route=generate_route
                )

        return stream_generate_routes
    def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
            -> list[str]:
        """
        Get answer start at node id.
        :param graph: graph
        :param target_node_id: target node ID
        :return:
        """
        nodes = graph.get('nodes')
        edges = graph.get('edges')

        # fetch all ingoing edges from source node
        ingoing_edges = []
        for edge in edges:
            if edge.get('target') == target_node_id:
                ingoing_edges.append(edge)

        if not ingoing_edges:
            # check if it's the first node in the iteration
            target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
            if not target_node:
                return []

            node_iteration_id = target_node.get('data', {}).get('iteration_id')
            # get iteration start node id
            for node in nodes:
                if node.get('id') == node_iteration_id:
                    if node.get('data', {}).get('start_node_id') == target_node_id:
                        return [target_node_id]

            return []

        start_node_ids = []
        for ingoing_edge in ingoing_edges:
            source_node_id = ingoing_edge.get('source')
            source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
            if not source_node:
                continue

            node_type = source_node.get('data', {}).get('type')
            node_iteration_id = source_node.get('data', {}).get('iteration_id')
            iteration_start_node_id = None
            if node_iteration_id:
                iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
                iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')

            if node_type in [
                NodeType.ANSWER.value,
                NodeType.IF_ELSE.value,
                NodeType.QUESTION_CLASSIFIER.value,
                NodeType.ITERATION.value,
                NodeType.LOOP.value
            ]:
                start_node_id = target_node_id
                start_node_ids.append(start_node_id)
            elif node_type == NodeType.START.value or \
                    node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
                start_node_id = source_node_id
                start_node_ids.append(start_node_id)
            else:
                sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
                if sub_start_node_ids:
                    start_node_ids.extend(sub_start_node_ids)

        return start_node_ids
    def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
        """
        Get iteration nested relations.
        :param graph: graph
        :return:
        """
        nodes = graph.get('nodes')

        iteration_ids = [node.get('id') for node in nodes
                         if node.get('data', {}).get('type') in [
                             NodeType.ITERATION.value,
                             NodeType.LOOP.value,
                         ]]

        return {
            iteration_id: [
                node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
            ] for iteration_id in iteration_ids
        }
    def _generate_stream_outputs_when_node_started(self) -> Generator:
        """
        Generate stream outputs.
        :return:
        """
        if self._task_state.current_stream_generate_state:
            route_chunks = self._task_state.current_stream_generate_state.generate_route[
                self._task_state.current_stream_generate_state.current_route_position:
            ]

            for route_chunk in route_chunks:
                if route_chunk.type == 'text':
                    route_chunk = cast(TextGenerateRouteChunk, route_chunk)

                    # handle output moderation chunk
                    should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
                    if should_direct_answer:
                        continue

                    self._task_state.answer += route_chunk.text
                    yield self._message_to_stream_response(route_chunk.text, self._message.id)
                else:
                    break

                self._task_state.current_stream_generate_state.current_route_position += 1

            # all route chunks are generated
            if self._task_state.current_stream_generate_state.current_route_position == len(
                self._task_state.current_stream_generate_state.generate_route
            ):
                self._task_state.current_stream_generate_state = None
    def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
        """
        Generate stream outputs.
        :return:
        """
        if not self._task_state.current_stream_generate_state:
            return

        route_chunks = self._task_state.current_stream_generate_state.generate_route[
            self._task_state.current_stream_generate_state.current_route_position:]

        for route_chunk in route_chunks:
            if route_chunk.type == 'text':
                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
                self._task_state.answer += route_chunk.text
                yield self._message_to_stream_response(route_chunk.text, self._message.id)
            else:
                value = None
                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                value_selector = route_chunk.value_selector
                if not value_selector:
                    self._task_state.current_stream_generate_state.current_route_position += 1
                    continue

                route_chunk_node_id = value_selector[0]

                if route_chunk_node_id == 'sys':
                    # system variable
                    value = contexts.workflow_variable_pool.get().get(value_selector)
                    if value:
                        value = value.text
                elif route_chunk_node_id in self._iteration_nested_relations:
                    # it's a iteration variable
                    if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
                        continue

                    iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
                    iterator = iteration_state.inputs
                    if not iterator:
                        continue

                    iterator_selector = iterator.get('iterator_selector', [])
                    if value_selector[1] == 'index':
                        value = iteration_state.current_index
                    elif value_selector[1] == 'item':
                        value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
                            iterator_selector
                        ) else None
                else:
                    # check chunk node id is before current node id or equal to current node id
                    if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
                        break

                    latest_node_execution_info = self._task_state.latest_node_execution_info

                    # get route chunk node execution info
                    route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
                    if (route_chunk_node_execution_info.node_type == NodeType.LLM
                            and latest_node_execution_info.node_type == NodeType.LLM):
                        # only LLM support chunk stream output
                        self._task_state.current_stream_generate_state.current_route_position += 1
                        continue

                    # get route chunk node execution
                    route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
                        WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
                    ).first()

                    outputs = route_chunk_node_execution.outputs_dict

                    # get value from outputs
                    value = None
                    for key in value_selector[1:]:
                        if not value:
                            value = outputs.get(key) if outputs else None
                        else:
                            value = value.get(key)

                if value is not None:
                    text = ''
                    if isinstance(value, str | int | float):
                        text = str(value)
                    elif isinstance(value, FileVar):
                        # convert file to markdown
                        text = value.to_markdown()
                    elif isinstance(value, dict):
                        # handle files
                        file_vars = self._fetch_files_from_variable_value(value)
                        if file_vars:
                            file_var = file_vars[0]
                            try:
                                file_var_obj = FileVar(**file_var)

                                # convert file to markdown
                                text = file_var_obj.to_markdown()
                            except Exception as e:
                                logger.error(f'Error creating file var: {e}')

                        if not text:
                            # other types
                            text = json.dumps(value, ensure_ascii=False)
                    elif isinstance(value, list):
                        # handle files
                        file_vars = self._fetch_files_from_variable_value(value)
                        for file_var in file_vars:
                            try:
                                file_var_obj = FileVar(**file_var)
                            except Exception as e:
                                logger.error(f'Error creating file var: {e}')
                                continue

                            # convert file to markdown
                            text = file_var_obj.to_markdown() + ' '

                        text = text.strip()

                        if not text and value:
                            # other types
                            text = json.dumps(value, ensure_ascii=False)

                    if text:
                        self._task_state.answer += text
                        yield self._message_to_stream_response(text, self._message.id)

            self._task_state.current_stream_generate_state.current_route_position += 1

        # all route chunks are generated
        if self._task_state.current_stream_generate_state.current_route_position == len(
            self._task_state.current_stream_generate_state.generate_route
        ):
            self._task_state.current_stream_generate_state = None
    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
        """
        Is stream out support
        :param event: queue text chunk event
        :return:
        """
        if not event.metadata:
            return True

        if 'node_id' not in event.metadata:
            return True

        node_type = event.metadata.get('node_type')
        stream_output_value_selector = event.metadata.get('value_selector')
        if not stream_output_value_selector:
            return False

        if not self._task_state.current_stream_generate_state:
            return False

        route_chunk = self._task_state.current_stream_generate_state.generate_route[
            self._task_state.current_stream_generate_state.current_route_position]

        if route_chunk.type != 'var':
            return False

        if node_type != NodeType.LLM:
            # only LLM support chunk stream output
            return False

        route_chunk = cast(VarGenerateRouteChunk, route_chunk)
        value_selector = route_chunk.value_selector

        # check chunk node id is before current node id or equal to current node id
        if value_selector != stream_output_value_selector:
            return False

        return True
    def _handle_output_moderation_chunk(self, text: str) -> bool:
        """
        Handle output moderation chunk.

@@ -541,23 +768,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

                # stop subscribe new token when output moderation should direct output
                self._task_state.answer = self._output_moderation_handler.get_final_output()
                self._queue_manager.publish(
                    QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
                )

                self._queue_manager.publish(
                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
                )
                return True
            else:
                self._output_moderation_handler.append_new_token(text)

        return False

    def _refetch_message(self) -> None:
        """
        Refetch message.
        :return:
        """
        message = db.session.query(Message).filter(Message.id == self._message.id).first()
        if message:
            self._message = message

View File

@@ -0,0 +1,203 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_node_execute_failed(self, node_id: str,
                                        node_type: NodeType,
                                        node_data: BaseNodeData,
                                        error: str,
                                        inputs: Optional[dict] = None,
                                        outputs: Optional[dict] = None,
                                        process_data: Optional[dict] = None) -> None:
        """
        Workflow node execute failed
        """
        self._queue_manager.publish(
            QueueNodeFailedEvent(
                node_id=node_id,
                node_type=node_type,
                node_data=node_data,
                inputs=inputs,
                outputs=outputs,
                process_data=process_data,
                error=error
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
        """
        Publish text chunk
        """
        self._queue_manager.publish(
            QueueTextChunkEvent(
                text=text,
                metadata={
                    "node_id": node_id,
                    **(metadata or {})  # guard against metadata=None before unpacking
                }
            ), PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_iteration_started(self,
                                      node_id: str,
                                      node_type: NodeType,
                                      node_run_index: int = 1,
                                      node_data: Optional[BaseNodeData] = None,
                                      inputs: Optional[dict] = None,
                                      predecessor_node_id: Optional[str] = None,
                                      metadata: Optional[dict] = None) -> None:
        """
        Publish iteration started
        """
        self._queue_manager.publish(
            QueueIterationStartEvent(
                node_id=node_id,
                node_type=node_type,
                node_run_index=node_run_index,
                node_data=node_data,
                inputs=inputs,
                predecessor_node_id=predecessor_node_id,
                metadata=metadata
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_iteration_next(self, node_id: str,
                                   node_type: NodeType,
                                   index: int,
                                   node_run_index: int,
                                   output: Optional[Any]) -> None:
        """
        Publish iteration next
        """
        self._queue_manager.publish(  # was `_publish`; use the public API like the other callbacks
            QueueIterationNextEvent(
                node_id=node_id,
                node_type=node_type,
                index=index,
                node_run_index=node_run_index,
                output=output
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_workflow_iteration_completed(self, node_id: str,
                                        node_type: NodeType,
                                        node_run_index: int,
                                        outputs: dict) -> None:
        """
        Publish iteration completed
        """
        self._queue_manager.publish(  # was `_publish`; use the public API like the other callbacks
            QueueIterationCompletedEvent(
                node_id=node_id,
                node_type=node_type,
                node_run_index=node_run_index,
                outputs=outputs
            ),
            PublishFrom.APPLICATION_MANAGER
        )

    def on_event(self, event: AppQueueEvent) -> None:
        """
        Publish event
        """
        self._queue_manager.publish(
            event,
            PublishFrom.APPLICATION_MANAGER
        )
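Every callback above funnels an event object into the queue manager's `publish` method together with a `PublishFrom` source. A minimal runnable sketch of that pattern, using simplified stand-in classes (not the actual Dify implementation):

```python
from enum import Enum
from typing import Any, Optional


class PublishFrom(Enum):
    APPLICATION_MANAGER = 1


class QueueTextChunkEvent:
    def __init__(self, text: str, metadata: dict) -> None:
        self.text = text
        self.metadata = metadata


class AppQueueManager:
    """Collects published events; the real manager pushes them onto a consumer queue."""

    def __init__(self) -> None:
        self.events: list[tuple[Any, PublishFrom]] = []

    def publish(self, event: Any, publish_from: PublishFrom) -> None:
        self.events.append((event, publish_from))


class WorkflowEventTriggerCallback:
    def __init__(self, queue_manager: AppQueueManager) -> None:
        self._queue_manager = queue_manager

    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
        # Guard against metadata=None before unpacking, as in the callback above.
        self._queue_manager.publish(
            QueueTextChunkEvent(text=text, metadata={"node_id": node_id, **(metadata or {})}),
            PublishFrom.APPLICATION_MANAGER,
        )


qm = AppQueueManager()
cb = WorkflowEventTriggerCallback(qm)
cb.on_node_text_chunk("node-1", "hello")
event, source = qm.events[0]
```

The callback layer stays ignorant of how events are consumed; it only tags each event with where it was published from.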

View File

@@ -28,19 +28,15 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
     """
     Agent Chatbot App Config Entity.
     """
 
     agent: Optional[AgentEntity] = None
 
 
 class AgentChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def get_app_config(
-        cls,
-        app_model: App,
-        app_model_config: AppModelConfig,
-        conversation: Optional[Conversation] = None,
-        override_config_dict: Optional[dict] = None,
-    ) -> AgentChatAppConfig:
+    def get_app_config(cls, app_model: App,
+                       app_model_config: AppModelConfig,
+                       conversation: Optional[Conversation] = None,
+                       override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
         """
         Convert app model config to agent chat app config
         :param app_model: app model
@@ -70,12 +66,22 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
             app_model_config_from=config_from,
             app_model_config_id=app_model_config.id,
             app_model_config_dict=config_dict,
-            model=ModelConfigManager.convert(config=config_dict),
-            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
-            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
-            dataset=DatasetConfigManager.convert(config=config_dict),
-            agent=AgentConfigManager.convert(config=config_dict),
-            additional_features=cls.convert_features(config_dict, app_mode),
+            model=ModelConfigManager.convert(
+                config=config_dict
+            ),
+            prompt_template=PromptTemplateConfigManager.convert(
+                config=config_dict
+            ),
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
+                config=config_dict
+            ),
+            dataset=DatasetConfigManager.convert(
+                config=config_dict
+            ),
+            agent=AgentConfigManager.convert(
+                config=config_dict
+            ),
+            additional_features=cls.convert_features(config_dict, app_mode)
         )
 
         app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@@ -122,8 +128,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         # suggested_questions_after_answer
         config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
-            config
-        )
+            config)
         related_config_keys.extend(current_related_config_keys)
 
         # speech_to_text
@@ -140,15 +145,13 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         # dataset configs
         # dataset_query_variable
-        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
-            tenant_id, app_mode, config
-        )
+        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
+                                                                                             config)
         related_config_keys.extend(current_related_config_keys)
 
         # moderation validation
-        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
-            tenant_id, config
-        )
+        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
+                                                                                                            config)
         related_config_keys.extend(current_related_config_keys)
 
         related_config_keys = list(set(related_config_keys))
@@ -167,7 +170,10 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         :param config: app model config args
         """
         if not config.get("agent_mode"):
-            config["agent_mode"] = {"enabled": False, "tools": []}
+            config["agent_mode"] = {
+                "enabled": False,
+                "tools": []
+            }
 
         if not isinstance(config["agent_mode"], dict):
             raise ValueError("agent_mode must be of object type")
@@ -181,9 +187,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
         if not config["agent_mode"].get("strategy"):
             config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
 
-        if config["agent_mode"]["strategy"] not in [
-            member.value for member in list(PlanningStrategy.__members__.values())
-        ]:
+        if config["agent_mode"]["strategy"] not in [member.value for member in
+                                                    list(PlanningStrategy.__members__.values())]:
             raise ValueError("strategy in agent_mode must be in the specified strategy list")
 
         if not config["agent_mode"].get("tools"):
@@ -205,7 +210,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
                     raise ValueError("enabled in agent_mode.tools must be of boolean type")
 
                 if key == "dataset":
-                    if "id" not in tool_item:
+                    if 'id' not in tool_item:
                         raise ValueError("id is required in dataset")
 
                     try:
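The validation hunks above default `agent_mode` to a disabled, empty-tools object, fall back to the `ROUTER` strategy, and reject strategies outside `PlanningStrategy`. A runnable sketch of those rules with a simplified enum and a hypothetical helper name (`validate_agent_mode` is not the Dify API; the rule bodies mirror the diff):

```python
from enum import Enum


class PlanningStrategy(Enum):
    # Reduced member set for illustration; the real enum has more strategies.
    ROUTER = "router"
    REACT = "react"


def validate_agent_mode(config: dict) -> dict:
    """Apply the defaulting/validation rules shown in the diff (simplified sketch)."""
    if not config.get("agent_mode"):
        config["agent_mode"] = {"enabled": False, "tools": []}

    if not isinstance(config["agent_mode"], dict):
        raise ValueError("agent_mode must be of object type")

    if not config["agent_mode"].get("strategy"):
        config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value

    if config["agent_mode"]["strategy"] not in [member.value for member in PlanningStrategy]:
        raise ValueError("strategy in agent_mode must be in the specified strategy list")

    if not config["agent_mode"].get("tools"):
        config["agent_mode"]["tools"] = []

    return config


cfg = validate_agent_mode({})
```

Defaulting before validating keeps a missing `agent_mode` section from being an error while still rejecting malformed ones.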

View File

@@ -3,7 +3,7 @@ import os
 import threading
 import uuid
 from collections.abc import Generator
-from typing import Any, Literal, Union, overload
+from typing import Any, Union
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
 from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
 from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
 from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
@@ -28,29 +28,12 @@ logger = logging.getLogger(__name__)
 
 class AgentChatAppGenerator(MessageBasedAppGenerator):
-    @overload
-    def generate(
-        self,
-        app_model: App,
-        user: Union[Account, EndUser],
-        args: dict,
-        invoke_from: InvokeFrom,
-        stream: Literal[True] = True,
-    ) -> Generator[dict, None, None]: ...
-
-    @overload
-    def generate(
-        self,
-        app_model: App,
-        user: Union[Account, EndUser],
-        args: dict,
-        invoke_from: InvokeFrom,
-        stream: Literal[False] = False,
-    ) -> dict: ...
-
-    def generate(
-        self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
-    ) -> Union[dict, Generator[dict, None, None]]:
+    def generate(self, app_model: App,
+                 user: Union[Account, EndUser],
+                 args: Any,
+                 invoke_from: InvokeFrom,
+                 stream: bool = True) \
+            -> Union[dict, Generator[dict, None, None]]:
         """
         Generate App response.
 
@@ -61,48 +44,60 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
         :param stream: is stream
         """
         if not stream:
-            raise ValueError("Agent Chat App does not support blocking mode")
+            raise ValueError('Agent Chat App does not support blocking mode')
 
-        if not args.get("query"):
-            raise ValueError("query is required")
+        if not args.get('query'):
+            raise ValueError('query is required')
 
-        query = args["query"]
+        query = args['query']
         if not isinstance(query, str):
-            raise ValueError("query must be a string")
+            raise ValueError('query must be a string')
 
-        query = query.replace("\x00", "")
-        inputs = args["inputs"]
+        query = query.replace('\x00', '')
+        inputs = args['inputs']
 
-        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
+        extras = {
+            "auto_generate_conversation_name": args.get('auto_generate_name', True)
+        }
 
         # get conversation
         conversation = None
-        if args.get("conversation_id"):
-            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
+        if args.get('conversation_id'):
+            conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
 
         # get app model config
-        app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
+        app_model_config = self._get_app_model_config(
+            app_model=app_model,
+            conversation=conversation
+        )
 
         # validate override model config
         override_model_config_dict = None
-        if args.get("model_config"):
+        if args.get('model_config'):
             if invoke_from != InvokeFrom.DEBUGGER:
-                raise ValueError("Only in App debug mode can override model config")
+                raise ValueError('Only in App debug mode can override model config')
 
             # validate config
             override_model_config_dict = AgentChatAppConfigManager.config_validate(
-                tenant_id=app_model.tenant_id, config=args.get("model_config")
+                tenant_id=app_model.tenant_id,
+                config=args.get('model_config')
            )
 
             # always enable retriever resource in debugger mode
-            override_model_config_dict["retriever_resource"] = {"enabled": True}
+            override_model_config_dict["retriever_resource"] = {
+                "enabled": True
+            }
 
         # parse files
-        files = args["files"] if args.get("files") else []
+        files = args['files'] if args.get('files') else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
+            file_objs = message_file_parser.validate_and_transform_files_arg(
+                files,
+                file_extra_config,
+                user
+            )
         else:
             file_objs = []
@@ -111,7 +106,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             app_model=app_model,
             app_model_config=app_model_config,
             conversation=conversation,
-            override_config_dict=override_model_config_dict,
+            override_config_dict=override_model_config_dict
         )
 
         # get tracing instance
@@ -132,11 +127,14 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=invoke_from,
             extras=extras,
             call_depth=0,
-            trace_manager=trace_manager,
+            trace_manager=trace_manager
         )
 
         # init generate records
-        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)
+        (
+            conversation,
+            message
+        ) = self._init_generate_records(application_generate_entity, conversation)
 
         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
@@ -145,20 +143,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=application_generate_entity.invoke_from,
             conversation_id=conversation.id,
             app_mode=conversation.mode,
-            message_id=message.id,
+            message_id=message.id
         )
 
         # new thread
-        worker_thread = threading.Thread(
-            target=self._generate_worker,
-            kwargs={
-                "flask_app": current_app._get_current_object(),
-                "application_generate_entity": application_generate_entity,
-                "queue_manager": queue_manager,
-                "conversation_id": conversation.id,
-                "message_id": message.id,
-            },
-        )
+        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
+            'flask_app': current_app._get_current_object(),
+            'application_generate_entity': application_generate_entity,
+            'queue_manager': queue_manager,
+            'conversation_id': conversation.id,
+            'message_id': message.id,
+        })
 
         worker_thread.start()
@@ -172,11 +167,13 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
         )
 
-        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
+        return AgentChatAppGenerateResponseConverter.convert(
+            response=response,
+            invoke_from=invoke_from
+        )
 
     def _generate_worker(
-        self,
-        flask_app: Flask,
+        self, flask_app: Flask,
         application_generate_entity: AgentChatAppGenerateEntity,
         queue_manager: AppQueueManager,
         conversation_id: str,
@@ -205,17 +202,18 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
                 conversation=conversation,
                 message=message,
             )
-        except GenerateTaskStoppedError:
+        except GenerateTaskStoppedException:
            pass
        except InvokeAuthorizationError:
            queue_manager.publish_error(
-                InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
+                InvokeAuthorizationError('Incorrect API key provided'),
+                PublishFrom.APPLICATION_MANAGER
            )
        except ValidationError as e:
            logger.exception("Validation Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
        except (ValueError, InvokeError) as e:
-            if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
+            if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
                logger.exception("Error when generating")
            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
        except Exception as e:
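The `@overload` definitions removed on the minus side of this hunk use `typing.Literal` so that static checkers see `stream=True` return a `Generator` and `stream=False` return a plain value, while a single runtime implementation handles both. A compact sketch of the same pattern on an illustrative function (not the Dify API):

```python
from collections.abc import Generator
from typing import Literal, Union, overload


@overload
def generate(query: str, stream: Literal[True] = True) -> Generator[str, None, None]: ...
@overload
def generate(query: str, stream: Literal[False] = False) -> str: ...


def generate(query: str, stream: bool = True) -> Union[str, Generator[str, None, None]]:
    # Single runtime implementation; the overloads above only narrow the
    # static return type based on the literal value of `stream`.
    if stream:
        return (chunk for chunk in query.split())
    return query


blocking = generate("a b c", stream=False)   # checker infers str
streamed = list(generate("a b c"))           # checker infers Generator[str, None, None]
```

The overloads carry no runtime behavior; removing them (as this diff side does) changes only what type checkers can infer at call sites.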

View File

@@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
 from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.moderation.base import ModerationError
+from core.moderation.base import ModerationException
 from core.tools.entities.tool_entities import ToolRuntimeVariablePool
 from extensions.ext_database import db
 from models.model import App, Conversation, Message, MessageAgentThought
@@ -30,8 +30,7 @@ class AgentChatAppRunner(AppRunner):
     """
 
     def run(
-        self,
-        application_generate_entity: AgentChatAppGenerateEntity,
+        self, application_generate_entity: AgentChatAppGenerateEntity,
         queue_manager: AppQueueManager,
         conversation: Conversation,
         message: Message,
@@ -66,7 +65,7 @@ class AgentChatAppRunner(AppRunner):
             prompt_template_entity=app_config.prompt_template,
             inputs=inputs,
             files=files,
-            query=query,
+            query=query
         )
 
         memory = None
@@ -74,10 +73,13 @@ class AgentChatAppRunner(AppRunner):
             # get memory of conversation (read-only)
             model_instance = ModelInstance(
                 provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
-                model=application_generate_entity.model_conf.model,
+                model=application_generate_entity.model_conf.model
            )
 
-            memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+            memory = TokenBufferMemory(
+                conversation=conversation,
+                model_instance=model_instance
+            )
 
         # organize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)
@@ -89,7 +91,7 @@ class AgentChatAppRunner(AppRunner):
             inputs=inputs,
             files=files,
             query=query,
-            memory=memory,
+            memory=memory
         )
 
         # moderation
@@ -101,15 +103,15 @@ class AgentChatAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 inputs=inputs,
                 query=query,
-                message_id=message.id,
+                message_id=message.id
            )
-        except ModerationError as e:
+        except ModerationException as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text=str(e),
-                stream=application_generate_entity.stream,
+                stream=application_generate_entity.stream
            )
            return
@@ -120,13 +122,13 @@ class AgentChatAppRunner(AppRunner):
                 message=message,
                 query=query,
                 user_id=application_generate_entity.user_id,
-                invoke_from=application_generate_entity.invoke_from,
+                invoke_from=application_generate_entity.invoke_from
            )
 
            if annotation_reply:
                queue_manager.publish(
                    QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
-                    PublishFrom.APPLICATION_MANAGER,
+                    PublishFrom.APPLICATION_MANAGER
                )
 
                self.direct_output(
@@ -134,7 +136,7 @@ class AgentChatAppRunner(AppRunner):
                     app_generate_entity=application_generate_entity,
                     prompt_messages=prompt_messages,
                     text=annotation_reply.content,
-                    stream=application_generate_entity.stream,
+                    stream=application_generate_entity.stream
                )
                return
@@ -146,7 +148,7 @@ class AgentChatAppRunner(AppRunner):
                 app_id=app_record.id,
                 external_data_tools=external_data_tools,
                 inputs=inputs,
-                query=query,
+                query=query
            )
 
         # reorganize all inputs and template to prompt messages
@@ -159,14 +161,14 @@ class AgentChatAppRunner(AppRunner):
             inputs=inputs,
             files=files,
             query=query,
-            memory=memory,
+            memory=memory
         )
 
         # check hosting moderation
         hosting_moderation_result = self.check_hosting_moderation(
             application_generate_entity=application_generate_entity,
             queue_manager=queue_manager,
-            prompt_messages=prompt_messages,
+            prompt_messages=prompt_messages
         )
 
         if hosting_moderation_result:
@@ -175,9 +177,9 @@ class AgentChatAppRunner(AppRunner):
         agent_entity = app_config.agent
 
         # load tool variables
-        tool_conversation_variables = self._load_tool_variables(
-            conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
-        )
+        tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
+                                                                user_id=application_generate_entity.user_id,
+                                                                tenant_id=app_config.tenant_id)
 
         # convert db variables to tool variables
         tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
@@ -185,7 +187,7 @@ class AgentChatAppRunner(AppRunner):
         # init model instance
         model_instance = ModelInstance(
             provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
-            model=application_generate_entity.model_conf.model,
+            model=application_generate_entity.model_conf.model
         )
         prompt_message, _ = self.organize_prompt_messages(
             app_record=app_record,
@@ -236,7 +238,7 @@ class AgentChatAppRunner(AppRunner):
             prompt_messages=prompt_message,
             variables_pool=tool_variables,
             db_variables=tool_conversation_variables,
-            model_instance=model_instance,
+            model_instance=model_instance
         )
 
         invoke_result = runner.run(
@@ -250,21 +252,17 @@ class AgentChatAppRunner(AppRunner):
             invoke_result=invoke_result,
             queue_manager=queue_manager,
             stream=application_generate_entity.stream,
-            agent=True,
+            agent=True
         )
 
     def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
         """
         load tool variables from database
         """
-        tool_variables: ToolConversationVariables = (
-            db.session.query(ToolConversationVariables)
-            .filter(
-                ToolConversationVariables.conversation_id == conversation_id,
-                ToolConversationVariables.tenant_id == tenant_id,
-            )
-            .first()
-        )
+        tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
+            ToolConversationVariables.conversation_id == conversation_id,
+            ToolConversationVariables.tenant_id == tenant_id
+        ).first()
 
         if tool_variables:
             # save tool variables to session, so that we can update it later
@@ -275,40 +273,34 @@ class AgentChatAppRunner(AppRunner):
                 conversation_id=conversation_id,
                 user_id=user_id,
                 tenant_id=tenant_id,
-                variables_str="[]",
+                variables_str='[]',
             )
             db.session.add(tool_variables)
             db.session.commit()
 
         return tool_variables
 
-    def _convert_db_variables_to_tool_variables(
-        self, db_variables: ToolConversationVariables
-    ) -> ToolRuntimeVariablePool:
+    def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
         """
         convert db variables to tool variables
         """
-        return ToolRuntimeVariablePool(
-            **{
-                "conversation_id": db_variables.conversation_id,
-                "user_id": db_variables.user_id,
-                "tenant_id": db_variables.tenant_id,
-                "pool": db_variables.variables,
-            }
-        )
+        return ToolRuntimeVariablePool(**{
+            'conversation_id': db_variables.conversation_id,
+            'user_id': db_variables.user_id,
+            'tenant_id': db_variables.tenant_id,
+            'pool': db_variables.variables
+        })
 
-    def _get_usage_of_all_agent_thoughts(
-        self, model_config: ModelConfigWithCredentialsEntity, message: Message
-    ) -> LLMUsage:
+    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity,
+                                         message: Message) -> LLMUsage:
         """
         Get usage of all agent thoughts
         :param model_config: model config
         :param message: message
         :return:
         """
-        agent_thoughts = (
-            db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
-        )
+        agent_thoughts = (db.session.query(MessageAgentThought)
+                          .filter(MessageAgentThought.message_id == message.id).all())
 
         all_message_tokens = 0
         all_answer_tokens = 0
@@ -320,5 +312,8 @@ class AgentChatAppRunner(AppRunner):
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
 
         return model_type_instance._calc_response_usage(
-            model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
+            model_config.model,
+            model_config.credentials,
+            all_message_tokens,
+            all_answer_tokens
         )

View File

@@ -23,15 +23,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         :return:
         """
         response = {
-            "event": "message",
-            "task_id": blocking_response.task_id,
-            "id": blocking_response.data.id,
-            "message_id": blocking_response.data.message_id,
-            "conversation_id": blocking_response.data.conversation_id,
-            "mode": blocking_response.data.mode,
-            "answer": blocking_response.data.answer,
-            "metadata": blocking_response.data.metadata,
-            "created_at": blocking_response.data.created_at,
+            'event': 'message',
+            'task_id': blocking_response.task_id,
+            'id': blocking_response.data.id,
+            'message_id': blocking_response.data.message_id,
+            'conversation_id': blocking_response.data.conversation_id,
+            'mode': blocking_response.data.mode,
+            'answer': blocking_response.data.answer,
+            'metadata': blocking_response.data.metadata,
+            'created_at': blocking_response.data.created_at
         }
 
         return response
@@ -45,15 +45,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         response = cls.convert_blocking_full_response(blocking_response)
 
-        metadata = response.get("metadata", {})
-        response["metadata"] = cls._get_simple_metadata(metadata)
+        metadata = response.get('metadata', {})
+        response['metadata'] = cls._get_simple_metadata(metadata)
 
         return response
 
     @classmethod
-    def convert_stream_full_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
-    ) -> Generator[str, None, None]:
+    def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
+            -> Generator[str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -64,14 +63,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield "ping"
+                yield 'ping'
                continue
 
            response_chunk = {
-                "event": sub_stream_response.event.value,
-                "conversation_id": chunk.conversation_id,
-                "message_id": chunk.message_id,
-                "created_at": chunk.created_at,
+                'event': sub_stream_response.event.value,
+                'conversation_id': chunk.conversation_id,
+                'message_id': chunk.message_id,
+                'created_at': chunk.created_at
            }
 
            if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -82,9 +81,8 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)
 
     @classmethod
-    def convert_stream_simple_response(
-        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
-    ) -> Generator[str, None, None]:
+    def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
+            -> Generator[str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -95,20 +93,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response
 
             if isinstance(sub_stream_response, PingStreamResponse):
-                yield "ping"
+                yield 'ping'
                continue
 
            response_chunk = {
-                "event": sub_stream_response.event.value,
-                "conversation_id": chunk.conversation_id,
-                "message_id": chunk.message_id,
-                "created_at": chunk.created_at,
+                'event': sub_stream_response.event.value,
+                'conversation_id': chunk.conversation_id,
+                'message_id': chunk.message_id,
+                'created_at': chunk.created_at
            }
 
            if isinstance(sub_stream_response, MessageEndStreamResponse):
                sub_stream_response_dict = sub_stream_response.to_dict()
-                metadata = sub_stream_response_dict.get("metadata", {})
-                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
+                metadata = sub_stream_response_dict.get('metadata', {})
+                sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
            if isinstance(sub_stream_response, ErrorStreamResponse):
                data = cls._error_to_stream_response(sub_stream_response.err)
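Both stream converters above follow the same shape: iterate the stream, yield a bare `'ping'` for keep-alive responses, and emit a JSON string for every data chunk. A reduced runnable sketch of that shape, using plain dicts in place of the real stream-response objects (the function name and chunk layout are illustrative):

```python
import json
from collections.abc import Generator


def convert_stream(chunks: list[dict]) -> Generator[str, None, None]:
    """Yield 'ping' for keep-alives and JSON strings for data chunks (simplified)."""
    for chunk in chunks:
        if chunk.get("event") == "ping":
            yield "ping"
            continue
        yield json.dumps({
            "event": chunk["event"],
            "conversation_id": chunk["conversation_id"],
            "message_id": chunk["message_id"],
            "created_at": chunk["created_at"],
        })


out = list(convert_stream([
    {"event": "ping"},
    {"event": "message", "conversation_id": "c1", "message_id": "m1", "created_at": 1},
]))
```

Keeping the ping path separate means keep-alives never pay the cost of JSON serialization and stay trivially recognizable to the transport layer.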

Some files were not shown because too many files have changed in this diff.