Compare commits


2 Commits

Author SHA1 Message Date
Yi  d688bebb1a  Merge branch 'main' into fix/note-node-zoom-issue  2024-08-19 15:13:33 +08:00
Yi  f2ad16cec5  fix: note editor zoom issue  2024-08-19 15:13:02 +08:00
1968 changed files with 47888 additions and 69594 deletions


@ -20,7 +20,7 @@ jobs:
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v45 uses: tj-actions/changed-files@v44
with: with:
files: api/** files: api/**
@ -66,7 +66,7 @@ jobs:
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v45 uses: tj-actions/changed-files@v44
with: with:
files: web/** files: web/**
@ -97,7 +97,7 @@ jobs:
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v45 uses: tj-actions/changed-files@v44
with: with:
files: | files: |
**.sh **.sh
@ -107,7 +107,7 @@ jobs:
dev/** dev/**
- name: Super-linter - name: Super-linter
uses: super-linter/super-linter/slim@v7 uses: super-linter/super-linter/slim@v6
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
env: env:
BASH_SEVERITY: warning BASH_SEVERITY: warning


@ -1,54 +0,0 @@
name: Check i18n Files and Create PR
on:
pull_request:
types: [closed]
branches: [main]
jobs:
check-and-update:
if: github.event.pull_request.merged == true
runs-on: ubuntu-latest
defaults:
run:
working-directory: web
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 2 # last 2 commits
- name: Check for file changes in i18n/en-US
id: check_files
run: |
recent_commit_sha=$(git rev-parse HEAD)
second_recent_commit_sha=$(git rev-parse HEAD~1)
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v2
with:
node-version: 'lts/*'
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
run: yarn install --frozen-lockfile
- name: Run npm script
if: env.FILES_CHANGED == 'true'
run: npm run auto-gen-i18n
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files'
body: This PR was automatically created to update i18n files based on changes in en-US locale.
branch: chore/automated-i18n-updates
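
The deleted workflow above gates every later step on whether the merge commit touched `i18n/en-US/*.ts`. A minimal Python sketch of the same check, assuming it is run from a clone with at least two commits (standalone helper, not part of the repository):

```python
import subprocess

def en_us_locale_changed(web_dir: str = "web") -> bool:
    """Return True if the last commit touched any i18n/en-US/*.ts file."""
    # Same comparison the workflow performs: the two most recent commits,
    # restricted to the en-US locale files.
    result = subprocess.run(
        ["git", "diff", "--name-only", "HEAD", "HEAD~1", "--", "i18n/en-US/*.ts"],
        cwd=web_dir,
        capture_output=True,
        text=True,
        check=True,
    )
    changed_files = [line for line in result.stdout.splitlines() if line.strip()]
    print(f"Changed files: {changed_files}")
    return bool(changed_files)

if __name__ == "__main__":
    # Mirrors the FILES_CHANGED flag the workflow writes to $GITHUB_ENV.
    print(f"FILES_CHANGED={'true' if en_us_locale_changed() else 'false'}")
```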

.gitignore

@ -178,4 +178,3 @@ pyrightconfig.json
api/.vscode api/.vscode
.idea/ .idea/
.vscode


Binary image file changed (1.7 KiB before, 1.7 KiB after).


@ -12,6 +12,5 @@
</component> </component>
<component name="VcsDirectoryMappings"> <component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" /> <mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component> </component>
</project> </project>


@ -5,8 +5,8 @@
"name": "Python: Flask", "name": "Python: Flask",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"python": "${workspaceFolder}/.venv/bin/python", "python": "${workspaceFolder}/api/.venv/bin/python",
"cwd": "${workspaceFolder}", "cwd": "${workspaceFolder}/api",
"envFile": ".env", "envFile": ".env",
"module": "flask", "module": "flask",
"justMyCode": true, "justMyCode": true,
@ -18,15 +18,15 @@
"args": [ "args": [
"run", "run",
"--host=0.0.0.0", "--host=0.0.0.0",
"--port=5001" "--port=5001",
] ]
}, },
{ {
"name": "Python: Celery", "name": "Python: Celery",
"type": "debugpy", "type": "debugpy",
"request": "launch", "request": "launch",
"python": "${workspaceFolder}/.venv/bin/python", "python": "${workspaceFolder}/api/.venv/bin/python",
"cwd": "${workspaceFolder}", "cwd": "${workspaceFolder}/api",
"module": "celery", "module": "celery",
"justMyCode": true, "justMyCode": true,
"envFile": ".env", "envFile": ".env",


@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
## Before you jump in ## Before you jump in
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: [Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
### Feature requests: ### Feature requests:


@ -8,7 +8,7 @@
## 在开始之前 ## 在开始之前
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类: [查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
### 功能请求: ### 功能请求:


@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは
## 飛び込む前に ## 飛び込む前に
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。 [既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
### 機能リクエスト ### 機能リクエスト


@ -1,156 +0,0 @@
Thật tuyệt vời khi bạn muốn đóng góp cho Dify! Chúng tôi rất mong chờ được thấy những gì bạn sẽ làm. Là một startup với nguồn nhân lực và tài chính hạn chế, chúng tôi có tham vọng lớn là thiết kế quy trình trực quan nhất để xây dựng và quản lý các ứng dụng LLM. Mọi sự giúp đỡ từ cộng đồng đều rất quý giá đối với chúng tôi.
Chúng tôi cần linh hoạt và làm việc nhanh chóng, nhưng đồng thời cũng muốn đảm bảo các cộng tác viên như bạn có trải nghiệm đóng góp thuận lợi nhất có thể. Chúng tôi đã tạo ra hướng dẫn đóng góp này nhằm giúp bạn làm quen với codebase và cách chúng tôi làm việc với các cộng tác viên, để bạn có thể nhanh chóng bắt tay vào phần thú vị.
Hướng dẫn này, cũng như bản thân Dify, đang trong quá trình cải tiến liên tục. Chúng tôi rất cảm kích sự thông cảm của bạn nếu đôi khi nó không theo kịp dự án thực tế, và chúng tôi luôn hoan nghênh mọi phản hồi để cải thiện.
Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [Thỏa thuận Cấp phép và Đóng góp](./LICENSE) ngắn gọn của chúng tôi. Cộng đồng cũng tuân thủ [quy tắc ứng xử](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md).
## Trước khi bắt đầu
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
### Yêu cầu tính năng:
* Nếu bạn đang tạo một yêu cầu tính năng mới, chúng tôi muốn bạn giải thích tính năng đề xuất sẽ đạt được điều gì và cung cấp càng nhiều thông tin chi tiết càng tốt. [@perzeusss](https://github.com/perzeuss) đã tạo một [Trợ lý Yêu cầu Tính năng](https://udify.app/chat/MK2kVSnw1gakVwMX) rất hữu ích để giúp bạn soạn thảo nhu cầu của mình. Hãy thử dùng nó nhé.
* Nếu bạn muốn chọn một vấn đề từ danh sách hiện có, chỉ cần để lại bình luận dưới vấn đề đó nói rằng bạn sẽ làm.
Một thành viên trong nhóm làm việc trong lĩnh vực liên quan sẽ được thông báo. Nếu mọi thứ ổn, họ sẽ cho phép bạn bắt đầu code. Chúng tôi yêu cầu bạn chờ đợi cho đến lúc đó trước khi bắt tay vào làm tính năng, để không lãng phí công sức của bạn nếu chúng tôi đề xuất thay đổi.
Tùy thuộc vào lĩnh vực mà tính năng đề xuất thuộc về, bạn có thể nói chuyện với các thành viên khác nhau trong nhóm. Dưới đây là danh sách các lĩnh vực mà các thành viên trong nhóm chúng tôi đang làm việc hiện tại:
| Thành viên | Phạm vi |
| ------------------------------------------------------------ | ---------------------------------------------------- |
| [@yeuoly](https://github.com/Yeuoly) | Thiết kế kiến trúc Agents |
| [@jyong](https://github.com/JohnJyong) | Thiết kế quy trình RAG |
| [@GarfieldDai](https://github.com/GarfieldDai) | Xây dựng quy trình làm việc |
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | Làm cho giao diện người dùng dễ sử dụng |
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | Trải nghiệm nhà phát triển, đầu mối liên hệ cho mọi vấn đề |
| [@takatost](https://github.com/takatost) | Định hướng và kiến trúc tổng thể sản phẩm |
Cách chúng tôi ưu tiên:
| Loại tính năng | Mức độ ưu tiên |
| ------------------------------------------------------------ | -------------- |
| Tính năng ưu tiên cao được gắn nhãn bởi thành viên trong nhóm | Ưu tiên cao |
| Yêu cầu tính năng phổ biến từ [bảng phản hồi cộng đồng](https://github.com/langgenius/dify/discussions/categories/feedbacks) của chúng tôi | Ưu tiên trung bình |
| Tính năng không quan trọng và cải tiến nhỏ | Ưu tiên thấp |
| Có giá trị nhưng không cấp bách | Tính năng tương lai |
### Những vấn đề khác (ví dụ: báo cáo lỗi, tối ưu hiệu suất, sửa lỗi chính tả):
* Bắt đầu code ngay lập tức.
Cách chúng tôi ưu tiên:
| Loại vấn đề | Mức độ ưu tiên |
| ------------------------------------------------------------ | -------------- |
| Lỗi trong các chức năng chính (không thể đăng nhập, ứng dụng không hoạt động, lỗ hổng bảo mật) | Nghiêm trọng |
| Lỗi không quan trọng, cải thiện hiệu suất | Ưu tiên trung bình |
| Sửa lỗi nhỏ (lỗi chính tả, giao diện người dùng gây nhầm lẫn nhưng vẫn hoạt động) | Ưu tiên thấp |
## Cài đặt
Dưới đây là các bước để thiết lập Dify cho việc phát triển:
### 1. Fork repository này
### 2. Clone repository
Clone repository đã fork từ terminal của bạn:
```
git clone git@github.com:<tên_người_dùng_github>/dify.git
```
### 3. Kiểm tra các phụ thuộc
Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đã được cài đặt trên hệ thống của bạn:
- [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) phiên bản 3.10.x
### 4. Cài đặt
Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt.
Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/self-host-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục.
### 5. Truy cập Dify trong trình duyệt của bạn
Để xác nhận cài đặt của bạn, hãy truy cập [http://localhost:3000](http://localhost:3000) (địa chỉ mặc định, hoặc URL và cổng bạn đã cấu hình) trong trình duyệt. Bạn sẽ thấy Dify đang chạy.
## Phát triển
Nếu bạn đang thêm một nhà cung cấp mô hình, [hướng dẫn này](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md) dành cho bạn.
Nếu bạn đang thêm một nhà cung cấp công cụ cho Agent hoặc Workflow, [hướng dẫn này](./api/core/tools/README.md) dành cho bạn.
Để giúp bạn nhanh chóng định hướng phần đóng góp của mình, dưới đây là một bản phác thảo ngắn gọn về cấu trúc backend & frontend của Dify:
### Backend
Backend của Dify được viết bằng Python sử dụng [Flask](https://flask.palletsprojects.com/en/3.0.x/). Nó sử dụng [SQLAlchemy](https://www.sqlalchemy.org/) cho ORM và [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) cho hàng đợi tác vụ. Logic xác thực được thực hiện thông qua Flask-login.
```
[api/]
├── constants // Các cài đặt hằng số được sử dụng trong toàn bộ codebase.
├── controllers // Định nghĩa các route API và logic xử lý yêu cầu.
├── core // Điều phối ứng dụng cốt lõi, tích hợp mô hình và công cụ.
├── docker // Cấu hình liên quan đến Docker & containerization.
├── events // Xử lý và xử lý sự kiện
├── extensions // Mở rộng với các framework/nền tảng bên thứ 3.
├── fields // Định nghĩa trường cho serialization/marshalling.
├── libs // Thư viện và tiện ích có thể tái sử dụng.
├── migrations // Script cho việc di chuyển cơ sở dữ liệu.
├── models // Mô hình cơ sở dữ liệu & định nghĩa schema.
├── services // Xác định logic nghiệp vụ.
├── storage // Lưu trữ khóa riêng tư.
├── tasks // Xử lý các tác vụ bất đồng bộ và công việc nền.
└── tests
```
### Frontend
Website được khởi tạo trên boilerplate [Next.js](https://nextjs.org/) bằng Typescript và sử dụng [Tailwind CSS](https://tailwindcss.com/) cho styling. [React-i18next](https://react.i18next.com/) được sử dụng cho việc quốc tế hóa.
```
[web/]
├── app // layouts, pages và components
│ ├── (commonLayout) // layout chung được sử dụng trong toàn bộ ứng dụng
│ ├── (shareLayout) // layouts được chia sẻ cụ thể cho các phiên dựa trên token
│ ├── activate // trang kích hoạt
│ ├── components // được chia sẻ bởi các trang và layouts
│ ├── install // trang cài đặt
│ ├── signin // trang đăng nhập
│ └── styles // styles được chia sẻ toàn cục
├── assets // Tài nguyên tĩnh
├── bin // scripts chạy ở bước build
├── config // cài đặt và tùy chọn có thể điều chỉnh
├── context // contexts được chia sẻ bởi các phần khác nhau của ứng dụng
├── dictionaries // File dịch cho từng ngôn ngữ
├── docker // cấu hình container
├── hooks // Hooks có thể tái sử dụng
├── i18n // Cấu hình quốc tế hóa
├── models // mô tả các mô hình dữ liệu & hình dạng của phản hồi API
├── public // tài nguyên meta như favicon
├── service // xác định hình dạng của các hành động API
├── test
├── types // mô tả các tham số hàm và giá trị trả về
└── utils // Các hàm tiện ích được chia sẻ
```
## Gửi PR của bạn
Cuối cùng, đã đến lúc mở một pull request (PR) đến repository của chúng tôi. Đối với các tính năng lớn, chúng tôi sẽ merge chúng vào nhánh `deploy/dev` để kiểm tra trước khi đưa vào nhánh `main`. Nếu bạn gặp vấn đề như xung đột merge hoặc không biết cách mở pull request, hãy xem [hướng dẫn về pull request của GitHub](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giới thiệu là một người đóng góp trong [README](https://github.com/langgenius/dify/blob/main/README.md) của chúng tôi.
## Nhận trợ giúp
Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.


@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer: 1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment. a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations. - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components. b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.


@ -39,7 +39,7 @@ DB_DATABASE=dify
# Storage configuration # Storage configuration
# use for store upload files, private keys... # use for store upload files, private keys...
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos # storage type: local, s3, azure-blob, google-storage
STORAGE_TYPE=local STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage STORAGE_LOCAL_PATH=storage
S3_USE_AWS_MANAGED_IAM=false S3_USE_AWS_MANAGED_IAM=false
@ -60,8 +60,7 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
ALIYUN_OSS_ENDPOINT=your-endpoint ALIYUN_OSS_ENDPOINT=your-endpoint
ALIYUN_OSS_AUTH_VERSION=v1 ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Google Storage configuration # Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
@ -73,12 +72,6 @@ TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme TENCENT_COS_SCHEME=your-scheme
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
# OCI Storage configuration # OCI Storage configuration
OCI_ENDPOINT=your-endpoint OCI_ENDPOINT=your-endpoint
OCI_BUCKET_NAME=your-bucket-name OCI_BUCKET_NAME=your-bucket-name
@ -86,13 +79,6 @@ OCI_ACCESS_KEY=your-access-key
OCI_SECRET_KEY=your-secret-key OCI_SECRET_KEY=your-secret-key
OCI_REGION=your-region OCI_REGION=your-region
# Volcengine tos Storage configuration
VOLCENGINE_TOS_ENDPOINT=your-endpoint
VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
VOLCENGINE_TOS_ACCESS_KEY=your-access-key
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
VOLCENGINE_TOS_REGION=your-region
# CORS configuration # CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@ -114,10 +100,11 @@ QDRANT_GRPC_ENABLED=false
QDRANT_GRPC_PORT=6334 QDRANT_GRPC_PORT=6334
# Milvus configuration # Milvus configuration
MILVUS_URI=http://127.0.0.1:19530 MILVUS_HOST=127.0.0.1
MILVUS_TOKEN= MILVUS_PORT=19530
MILVUS_USER=root MILVUS_USER=root
MILVUS_PASSWORD=Milvus MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# MyScale configuration # MyScale configuration
MYSCALE_HOST=127.0.0.1 MYSCALE_HOST=127.0.0.1
@ -260,8 +247,8 @@ API_TOOL_DEFAULT_READ_TIMEOUT=60
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300 HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
HTTP_REQUEST_MAX_READ_TIMEOUT=600 HTTP_REQUEST_MAX_READ_TIMEOUT=600
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600 HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 # 10MB
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # 1MB
# Log file path # Log file path
LOG_FILE= LOG_FILE=
@ -280,13 +267,4 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration # Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1 CELERY_BEAT_SCHEDULER_TIME=1
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES=
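
Among the `.env.example` changes above, the Milvus settings switch between a single `MILVUS_URI`/`MILVUS_TOKEN` pair and the `MILVUS_HOST`/`MILVUS_PORT`/`MILVUS_USER`/`MILVUS_PASSWORD`/`MILVUS_SECURE` set. A minimal sketch of how the two variable styles map onto connection arguments (illustrative only, not Dify's vector-store factory):

```python
import os

def milvus_connect_kwargs() -> dict:
    """Build Milvus connection arguments from whichever variable style is present."""
    if os.getenv("MILVUS_URI"):
        # URI style: one endpoint string plus an optional token.
        return {"uri": os.environ["MILVUS_URI"], "token": os.getenv("MILVUS_TOKEN", "")}
    # Host/port style, with the defaults shown in the diff.
    return {
        "host": os.getenv("MILVUS_HOST", "127.0.0.1"),
        "port": int(os.getenv("MILVUS_PORT", "19530")),
        "user": os.getenv("MILVUS_USER", "root"),
        "password": os.getenv("MILVUS_PASSWORD", "Milvus"),
        "secure": os.getenv("MILVUS_SECURE", "false").lower() == "true",
    }
```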


@ -5,10 +5,6 @@ WORKDIR /app/api
# Install Poetry # Install Poetry
ENV POETRY_VERSION=1.8.3 ENV POETRY_VERSION=1.8.3
# if you located in China, you can use aliyun mirror to speed up
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
RUN pip install --no-cache-dir poetry==${POETRY_VERSION} RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
# Configure Poetry # Configure Poetry
@ -20,9 +16,6 @@ ENV POETRY_REQUESTS_TIMEOUT=15
FROM base AS packages FROM base AS packages
# if you located in China, you can use aliyun mirror to speed up
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
@ -50,12 +43,10 @@ WORKDIR /app/api
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \ && apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# if you located in China, you can use aliyun mirror to speed up
# && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \ && apt-get update \
# For Security # For Security
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \ && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-2 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get autoremove -y \ && apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
@ -65,7 +56,7 @@ COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data # Download nltk data
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')" RUN python -c "import nltk; nltk.download('punkt')"
# Copy source code # Copy source code
COPY . /app/api/ COPY . /app/api/
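
The two sides of the Dockerfile above differ in which NLTK resources are fetched at build time (`punkt` alone versus `punkt` plus `averaged_perceptron_tagger`). A small sketch for checking whether both are present in a built image (assumes `nltk` is installed; not part of the Dockerfile itself):

```python
import nltk

# Resources referenced by the two variants of the build step.
for resource in ("tokenizers/punkt", "taggers/averaged_perceptron_tagger"):
    try:
        nltk.data.find(resource)
        print(f"{resource}: available")
    except LookupError:
        print(f"{resource}: missing")
```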


@ -411,8 +411,7 @@ def migrate_knowledge_vector_database():
try: try:
click.echo( click.echo(
click.style( click.style(
f"Start to created vector index with {len(documents)} documents of {segments_count}" f"Start to created vector index with {len(documents)} documents of {segments_count} segments for dataset {dataset.id}.",
f" segments for dataset {dataset.id}.",
fg="green", fg="green",
) )
) )
@ -560,9 +559,8 @@ def add_qdrant_doc_id_index(field: str):
@click.command("create-tenant", help="Create account and tenant.") @click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="The email address of the tenant account.") @click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.") @click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None): def create_tenant(email: str, language: Optional[str] = None):
""" """
Create tenant account Create tenant account
""" """
@ -582,15 +580,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
if language not in languages: if language not in languages:
language = "en-US" language = "en-US"
name = name.strip()
# generate random password # generate random password
new_password = secrets.token_urlsafe(16) new_password = secrets.token_urlsafe(16)
# register account # register account
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
TenantService.create_owner_tenant_if_not_exist(account, name) TenantService.create_owner_tenant_if_not_exist(account)
click.echo( click.echo(
click.style( click.style(
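
One side of the `commands.py` hunk above gives `create-tenant` a `--name` option for the workspace and passes it through to tenant creation. A standalone Click sketch of that option handling (illustrative only; the language whitelist and the final echo are hypothetical, and no account is actually registered):

```python
import secrets
from typing import Optional

import click

@click.command("create-tenant", help="Create account and tenant.")
@click.option("--email", prompt=True, help="The email address of the tenant account.")
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
@click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
    if language not in ("en-US", "zh-Hans"):  # hypothetical whitelist
        language = "en-US"
    name = (name or "").strip()
    new_password = secrets.token_urlsafe(16)  # random initial password, as in the diff
    click.echo(f"Would register {email} with workspace {name!r}, "
               f"language {language} and password {new_password}")

if __name__ == "__main__":
    create_tenant()
```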


@ -1,3 +1,3 @@
from .app_config import DifyConfig from .app_config import DifyConfig
dify_config = DifyConfig() dify_config = DifyConfig()


@ -1,3 +1,4 @@
from pydantic import Field, computed_field
from pydantic_settings import SettingsConfigDict from pydantic_settings import SettingsConfigDict
from configs.deploy import DeploymentConfig from configs.deploy import DeploymentConfig
@ -23,16 +24,42 @@ class DifyConfig(
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.** # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig, EnterpriseFeatureConfig,
): ):
DEBUG: bool = Field(default=False, description='whether to enable debug mode.')
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
# read from dotenv format config file # read from dotenv format config file
env_file=".env", env_file='.env',
env_file_encoding="utf-8", env_file_encoding='utf-8',
frozen=True, frozen=True,
# ignore extra attributes # ignore extra attributes
extra="ignore", extra='ignore',
) )
# Before adding any config, CODE_MAX_NUMBER: int = 9223372036854775807
# please consider to arrange it in the proper config group of existed or added CODE_MIN_NUMBER: int = -9223372036854775808
# for better readability and maintainability. CODE_MAX_STRING_LENGTH: int = 80000
# Thanks for your concentration and consideration. CODE_MAX_STRING_ARRAY_LENGTH: int = 30
CODE_MAX_OBJECT_ARRAY_LENGTH: int = 30
CODE_MAX_NUMBER_ARRAY_LENGTH: int = 1000
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = 300
HTTP_REQUEST_MAX_READ_TIMEOUT: int = 600
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = 600
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = 1024 * 1024 * 10
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB'
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: int = 1024 * 1024
@computed_field
def HTTP_REQUEST_NODE_READABLE_MAX_TEXT_SIZE(self) -> str:
return f'{self.HTTP_REQUEST_NODE_MAX_TEXT_SIZE / 1024 / 1024:.2f}MB'
SSRF_PROXY_HTTP_URL: str | None = None
SSRF_PROXY_HTTPS_URL: str | None = None
MODERATION_BUFFER_SIZE: int = Field(default=300, description='The buffer size for moderation.')
MAX_VARIABLE_SIZE: int = Field(default=5 * 1024, description='The maximum size of a variable. default is 5KB.')
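
One side of the `app_config.py` hunk derives human-readable size strings from the raw byte limits with pydantic's `computed_field`. A minimal sketch of the pattern, assuming pydantic v2 and pydantic-settings (toy settings class, not the real `DifyConfig`):

```python
from pydantic import Field, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict

class HttpRequestNodeSketch(BaseSettings):
    """Toy settings class showing the readable-size computed field."""

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")

    HTTP_REQUEST_NODE_MAX_BINARY_SIZE: int = Field(
        default=10 * 1024 * 1024,
        description="max binary size in bytes for the HTTP request node",
    )

    @computed_field  # exposed alongside regular fields when the model is dumped
    @property
    def HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE(self) -> str:
        return f"{self.HTTP_REQUEST_NODE_MAX_BINARY_SIZE / 1024 / 1024:.2f}MB"

print(HttpRequestNodeSketch().HTTP_REQUEST_NODE_READABLE_MAX_BINARY_SIZE)  # "10.00MB"
```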


@ -6,28 +6,22 @@ class DeploymentConfig(BaseSettings):
""" """
Deployment configs Deployment configs
""" """
APPLICATION_NAME: str = Field( APPLICATION_NAME: str = Field(
description="application name", description='application name',
default="langgenius/dify", default='langgenius/dify',
)
DEBUG: bool = Field(
description="whether to enable debug mode.",
default=False,
) )
TESTING: bool = Field( TESTING: bool = Field(
description="", description='',
default=False, default=False,
) )
EDITION: str = Field( EDITION: str = Field(
description="deployment edition", description='deployment edition',
default="SELF_HOSTED", default='SELF_HOSTED',
) )
DEPLOY_ENV: str = Field( DEPLOY_ENV: str = Field(
description="deployment environment, default to PRODUCTION.", description='deployment environment, default to PRODUCTION.',
default="PRODUCTION", default='PRODUCTION',
) )


@ -7,14 +7,13 @@ class EnterpriseFeatureConfig(BaseSettings):
Enterprise feature configs. Enterprise feature configs.
**Before using, please contact business@dify.ai by email to inquire about licensing matters.** **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
""" """
ENTERPRISE_ENABLED: bool = Field( ENTERPRISE_ENABLED: bool = Field(
description="whether to enable enterprise features." description='whether to enable enterprise features.'
"Before using, please contact business@dify.ai by email to inquire about licensing matters.", 'Before using, please contact business@dify.ai by email to inquire about licensing matters.',
default=False, default=False,
) )
CAN_REPLACE_LOGO: bool = Field( CAN_REPLACE_LOGO: bool = Field(
description="whether to allow replacing enterprise logo.", description='whether to allow replacing enterprise logo.',
default=False, default=False,
) )


@ -8,28 +8,27 @@ class NotionConfig(BaseSettings):
""" """
Notion integration configs Notion integration configs
""" """
NOTION_CLIENT_ID: Optional[str] = Field( NOTION_CLIENT_ID: Optional[str] = Field(
description="Notion client ID", description='Notion client ID',
default=None, default=None,
) )
NOTION_CLIENT_SECRET: Optional[str] = Field( NOTION_CLIENT_SECRET: Optional[str] = Field(
description="Notion client secret key", description='Notion client secret key',
default=None, default=None,
) )
NOTION_INTEGRATION_TYPE: Optional[str] = Field( NOTION_INTEGRATION_TYPE: Optional[str] = Field(
description="Notion integration type, default to None, available values: internal.", description='Notion integration type, default to None, available values: internal.',
default=None, default=None,
) )
NOTION_INTERNAL_SECRET: Optional[str] = Field( NOTION_INTERNAL_SECRET: Optional[str] = Field(
description="Notion internal secret key", description='Notion internal secret key',
default=None, default=None,
) )
NOTION_INTEGRATION_TOKEN: Optional[str] = Field( NOTION_INTEGRATION_TOKEN: Optional[str] = Field(
description="Notion integration token", description='Notion integration token',
default=None, default=None,
) )


@ -8,18 +8,17 @@ class SentryConfig(BaseSettings):
""" """
Sentry configs Sentry configs
""" """
SENTRY_DSN: Optional[str] = Field( SENTRY_DSN: Optional[str] = Field(
description="Sentry DSN", description='Sentry DSN',
default=None, default=None,
) )
SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field( SENTRY_TRACES_SAMPLE_RATE: NonNegativeFloat = Field(
description="Sentry trace sample rate", description='Sentry trace sample rate',
default=1.0, default=1.0,
) )
SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field( SENTRY_PROFILES_SAMPLE_RATE: NonNegativeFloat = Field(
description="Sentry profiles sample rate", description='Sentry profiles sample rate',
default=1.0, default=1.0,
) )


@ -1,6 +1,6 @@
from typing import Annotated, Optional from typing import Optional
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field from pydantic import AliasChoices, Field, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig from configs.feature.hosted_service import HostedServiceConfig
@ -10,17 +10,16 @@ class SecurityConfig(BaseSettings):
""" """
Secret Key configs Secret Key configs
""" """
SECRET_KEY: Optional[str] = Field( SECRET_KEY: Optional[str] = Field(
description="Your App secret key will be used for securely signing the session cookie" description='Your App secret key will be used for securely signing the session cookie'
"Make sure you are changing this key for your deployment with a strong key." 'Make sure you are changing this key for your deployment with a strong key.'
"You can generate a strong key using `openssl rand -base64 42`." 'You can generate a strong key using `openssl rand -base64 42`.'
"Alternatively you can set it with `SECRET_KEY` environment variable.", 'Alternatively you can set it with `SECRET_KEY` environment variable.',
default=None, default=None,
) )
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field( RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
description="Expiry time in hours for reset token", description='Expiry time in hours for reset token',
default=24, default=24,
) )
@ -29,13 +28,12 @@ class AppExecutionConfig(BaseSettings):
""" """
App Execution configs App Execution configs
""" """
APP_MAX_EXECUTION_TIME: PositiveInt = Field( APP_MAX_EXECUTION_TIME: PositiveInt = Field(
description="execution timeout in seconds for app execution", description='execution timeout in seconds for app execution',
default=1200, default=1200,
) )
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field( APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="max active request per app, 0 means unlimited", description='max active request per app, 0 means unlimited',
default=0, default=0,
) )
@ -44,70 +42,14 @@ class CodeExecutionSandboxConfig(BaseSettings):
""" """
Code Execution Sandbox configs Code Execution Sandbox configs
""" """
CODE_EXECUTION_ENDPOINT: str = Field(
CODE_EXECUTION_ENDPOINT: HttpUrl = Field( description='endpoint URL of code execution servcie',
description="endpoint URL of code execution service", default='http://sandbox:8194',
default="http://sandbox:8194",
) )
CODE_EXECUTION_API_KEY: str = Field( CODE_EXECUTION_API_KEY: str = Field(
description="API key for code execution service", description='API key for code execution service',
default="dify-sandbox", default='dify-sandbox',
)
CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
description="connect timeout in seconds for code execution request",
default=10.0,
)
CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
description="read timeout in seconds for code execution request",
default=60.0,
)
CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
description="write timeout in seconds for code execution request",
default=10.0,
)
CODE_MAX_NUMBER: PositiveInt = Field(
description="max depth for code execution",
default=9223372036854775807,
)
CODE_MIN_NUMBER: NegativeInt = Field(
description="",
default=-9223372036854775807,
)
CODE_MAX_DEPTH: PositiveInt = Field(
description="max depth for code execution",
default=5,
)
CODE_MAX_PRECISION: PositiveInt = Field(
description="max precision digits for float type in code execution",
default=20,
)
CODE_MAX_STRING_LENGTH: PositiveInt = Field(
description="max string length for code execution",
default=80000,
)
CODE_MAX_STRING_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=30,
)
CODE_MAX_OBJECT_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=30,
)
CODE_MAX_NUMBER_ARRAY_LENGTH: PositiveInt = Field(
description="",
default=1000,
) )
@ -115,27 +57,28 @@ class EndpointConfig(BaseSettings):
""" """
Module URL configs Module URL configs
""" """
CONSOLE_API_URL: str = Field( CONSOLE_API_URL: str = Field(
description="The backend URL prefix of the console API." description='The backend URL prefix of the console API.'
"used to concatenate the login authorization callback or notion integration callback.", 'used to concatenate the login authorization callback or notion integration callback.',
default="", default='',
) )
CONSOLE_WEB_URL: str = Field( CONSOLE_WEB_URL: str = Field(
description="The front-end URL prefix of the console web." description='The front-end URL prefix of the console web.'
"used to concatenate some front-end addresses and for CORS configuration use.", 'used to concatenate some front-end addresses and for CORS configuration use.',
default="", default='',
) )
SERVICE_API_URL: str = Field( SERVICE_API_URL: str = Field(
description="Service API Url prefix." "used to display Service API Base Url to the front-end.", description='Service API Url prefix.'
default="", 'used to display Service API Base Url to the front-end.',
default='',
) )
APP_WEB_URL: str = Field( APP_WEB_URL: str = Field(
description="WebApp Url prefix." "used to display WebAPP API Base Url to the front-end.", description='WebApp Url prefix.'
default="", 'used to display WebAPP API Base Url to the front-end.',
default='',
) )
@ -143,18 +86,17 @@ class FileAccessConfig(BaseSettings):
""" """
File Access configs File Access configs
""" """
FILES_URL: str = Field( FILES_URL: str = Field(
description="File preview or download Url prefix." description='File preview or download Url prefix.'
" used to display File preview or download Url to the front-end or as Multi-model inputs;" ' used to display File preview or download Url to the front-end or as Multi-model inputs;'
"Url is signed and has expiration time.", 'Url is signed and has expiration time.',
validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"), validation_alias=AliasChoices('FILES_URL', 'CONSOLE_API_URL'),
alias_priority=1, alias_priority=1,
default="", default='',
) )
FILES_ACCESS_TIMEOUT: int = Field( FILES_ACCESS_TIMEOUT: int = Field(
description="timeout in seconds for file accessing", description='timeout in seconds for file accessing',
default=300, default=300,
) )
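
`FileAccessConfig` above resolves `FILES_URL` through an `AliasChoices` fallback: if `FILES_URL` is unset, the value of `CONSOLE_API_URL` is used instead. A toy reproduction of that lookup (the example URL is hypothetical):

```python
import os

from pydantic import AliasChoices, Field
from pydantic_settings import BaseSettings

class FileAccessSketch(BaseSettings):
    """Toy model for the FILES_URL fallback, not the real FileAccessConfig."""

    FILES_URL: str = Field(
        default="",
        # Try FILES_URL first; if it is unset, fall back to CONSOLE_API_URL.
        validation_alias=AliasChoices("FILES_URL", "CONSOLE_API_URL"),
        alias_priority=1,
    )

os.environ.setdefault("CONSOLE_API_URL", "https://console.example.com")  # hypothetical value
print(FileAccessSketch().FILES_URL)  # -> https://console.example.com
```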
@ -163,24 +105,23 @@ class FileUploadConfig(BaseSettings):
""" """
File Uploading configs File Uploading configs
""" """
UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field( UPLOAD_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="size limit in Megabytes for uploading files", description='size limit in Megabytes for uploading files',
default=15, default=15,
) )
UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field( UPLOAD_FILE_BATCH_LIMIT: NonNegativeInt = Field(
description="batch size limit for uploading files", description='batch size limit for uploading files',
default=5, default=5,
) )
UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field( UPLOAD_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
description="image file size limit in Megabytes for uploading files", description='image file size limit in Megabytes for uploading files',
default=10, default=10,
) )
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field( BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
description="", # todo: to be clarified description='', # todo: to be clarified
default=20, default=20,
) )
@ -189,79 +130,45 @@ class HttpConfig(BaseSettings):
""" """
HTTP configs HTTP configs
""" """
API_COMPRESSION_ENABLED: bool = Field( API_COMPRESSION_ENABLED: bool = Field(
description="whether to enable HTTP response compression of gzip", description='whether to enable HTTP response compression of gzip',
default=False, default=False,
) )
inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field( inner_CONSOLE_CORS_ALLOW_ORIGINS: str = Field(
description="", description='',
validation_alias=AliasChoices("CONSOLE_CORS_ALLOW_ORIGINS", "CONSOLE_WEB_URL"), validation_alias=AliasChoices('CONSOLE_CORS_ALLOW_ORIGINS', 'CONSOLE_WEB_URL'),
default="", default='',
) )
@computed_field @computed_field
@property @property
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",") return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(',')
inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field( inner_WEB_API_CORS_ALLOW_ORIGINS: str = Field(
description="", description='',
validation_alias=AliasChoices("WEB_API_CORS_ALLOW_ORIGINS"), validation_alias=AliasChoices('WEB_API_CORS_ALLOW_ORIGINS'),
default="*", default='*',
) )
@computed_field @computed_field
@property @property
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(',')
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
] = 10
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
] = 60
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
] = 20
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
description="",
default=10 * 1024 * 1024,
)
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: PositiveInt = Field(
description="",
default=1 * 1024 * 1024,
)
SSRF_PROXY_HTTP_URL: Optional[str] = Field(
description="HTTP URL for SSRF proxy",
default=None,
)
SSRF_PROXY_HTTPS_URL: Optional[str] = Field(
description="HTTPS URL for SSRF proxy",
default=None,
)
class InnerAPIConfig(BaseSettings): class InnerAPIConfig(BaseSettings):
""" """
Inner API configs Inner API configs
""" """
INNER_API: bool = Field( INNER_API: bool = Field(
description="whether to enable the inner API", description='whether to enable the inner API',
default=False, default=False,
) )
INNER_API_KEY: Optional[str] = Field( INNER_API_KEY: Optional[str] = Field(
description="The inner API key is used to authenticate the inner API", description='The inner API key is used to authenticate the inner API',
default=None, default=None,
) )
@ -272,27 +179,28 @@ class LoggingConfig(BaseSettings):
""" """
LOG_LEVEL: str = Field( LOG_LEVEL: str = Field(
description="Log output level, default to INFO." "It is recommended to set it to ERROR for production.", description='Log output level, default to INFO.'
default="INFO", 'It is recommended to set it to ERROR for production.',
default='INFO',
) )
LOG_FILE: Optional[str] = Field( LOG_FILE: Optional[str] = Field(
description="logging output file path", description='logging output file path',
default=None, default=None,
) )
LOG_FORMAT: str = Field( LOG_FORMAT: str = Field(
description="log format", description='log format',
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s", default='%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s',
) )
LOG_DATEFORMAT: Optional[str] = Field( LOG_DATEFORMAT: Optional[str] = Field(
description="log date format", description='log date format',
default=None, default=None,
) )
LOG_TZ: Optional[str] = Field( LOG_TZ: Optional[str] = Field(
description="specify log timezone, eg: America/New_York", description='specify log timezone, eg: America/New_York',
default=None, default=None,
) )
@ -301,9 +209,8 @@ class ModelLoadBalanceConfig(BaseSettings):
""" """
Model load balance configs Model load balance configs
""" """
MODEL_LB_ENABLED: bool = Field( MODEL_LB_ENABLED: bool = Field(
description="whether to enable model load balancing", description='whether to enable model load balancing',
default=False, default=False,
) )
@ -312,9 +219,8 @@ class BillingConfig(BaseSettings):
""" """
Platform Billing Configurations Platform Billing Configurations
""" """
BILLING_ENABLED: bool = Field( BILLING_ENABLED: bool = Field(
description="whether to enable billing", description='whether to enable billing',
default=False, default=False,
) )
@ -323,10 +229,9 @@ class UpdateConfig(BaseSettings):
""" """
Update configs Update configs
""" """
CHECK_UPDATE_URL: str = Field( CHECK_UPDATE_URL: str = Field(
description="url for checking updates", description='url for checking updates',
default="https://updates.dify.ai", default='https://updates.dify.ai',
) )
@ -336,53 +241,47 @@ class WorkflowConfig(BaseSettings):
""" """
WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field( WORKFLOW_MAX_EXECUTION_STEPS: PositiveInt = Field(
description="max execution steps in single workflow execution", description='max execution steps in single workflow execution',
default=500, default=500,
) )
WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field( WORKFLOW_MAX_EXECUTION_TIME: PositiveInt = Field(
description="max execution time in seconds in single workflow execution", description='max execution time in seconds in single workflow execution',
default=1200, default=1200,
) )
WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field( WORKFLOW_CALL_MAX_DEPTH: PositiveInt = Field(
description="max depth of calling in single workflow execution", description='max depth of calling in single workflow execution',
default=5, default=5,
) )
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="The maximum size in bytes of a variable. default to 5KB.",
default=5 * 1024,
)
class OAuthConfig(BaseSettings): class OAuthConfig(BaseSettings):
""" """
oauth configs oauth configs
""" """
OAUTH_REDIRECT_PATH: str = Field( OAUTH_REDIRECT_PATH: str = Field(
description="redirect path for OAuth", description='redirect path for OAuth',
default="/console/api/oauth/authorize", default='/console/api/oauth/authorize',
) )
GITHUB_CLIENT_ID: Optional[str] = Field( GITHUB_CLIENT_ID: Optional[str] = Field(
description="GitHub client id for OAuth", description='GitHub client id for OAuth',
default=None, default=None,
) )
GITHUB_CLIENT_SECRET: Optional[str] = Field( GITHUB_CLIENT_SECRET: Optional[str] = Field(
description="GitHub client secret key for OAuth", description='GitHub client secret key for OAuth',
default=None, default=None,
) )
GOOGLE_CLIENT_ID: Optional[str] = Field( GOOGLE_CLIENT_ID: Optional[str] = Field(
description="Google client id for OAuth", description='Google client id for OAuth',
default=None, default=None,
) )
GOOGLE_CLIENT_SECRET: Optional[str] = Field( GOOGLE_CLIENT_SECRET: Optional[str] = Field(
description="Google client secret key for OAuth", description='Google client secret key for OAuth',
default=None, default=None,
) )
@ -392,8 +291,9 @@ class ModerationConfig(BaseSettings):
Moderation in app configs. Moderation in app configs.
""" """
MODERATION_BUFFER_SIZE: PositiveInt = Field( # todo: to be clarified in usage and unit
description="buffer size for moderation", OUTPUT_MODERATION_BUFFER_SIZE: PositiveInt = Field(
description='buffer size for moderation',
default=300, default=300,
) )
@ -404,7 +304,7 @@ class ToolConfig(BaseSettings):
""" """
TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field( TOOL_ICON_CACHE_MAX_AGE: PositiveInt = Field(
description="max age in seconds for tool icon caching", description='max age in seconds for tool icon caching',
default=3600, default=3600,
) )
@ -415,52 +315,52 @@ class MailConfig(BaseSettings):
""" """
MAIL_TYPE: Optional[str] = Field( MAIL_TYPE: Optional[str] = Field(
description="Mail provider type name, default to None, available values are `smtp` and `resend`.", description='Mail provider type name, default to None, availabile values are `smtp` and `resend`.',
default=None, default=None,
) )
MAIL_DEFAULT_SEND_FROM: Optional[str] = Field( MAIL_DEFAULT_SEND_FROM: Optional[str] = Field(
description="default email address for sending from ", description='default email address for sending from ',
default=None, default=None,
) )
RESEND_API_KEY: Optional[str] = Field( RESEND_API_KEY: Optional[str] = Field(
description="API key for Resend", description='API key for Resend',
default=None, default=None,
) )
RESEND_API_URL: Optional[str] = Field( RESEND_API_URL: Optional[str] = Field(
description="API URL for Resend", description='API URL for Resend',
default=None, default=None,
) )
SMTP_SERVER: Optional[str] = Field( SMTP_SERVER: Optional[str] = Field(
description="smtp server host", description='smtp server host',
default=None, default=None,
) )
SMTP_PORT: Optional[int] = Field( SMTP_PORT: Optional[int] = Field(
description="smtp server port", description='smtp server port',
default=465, default=465,
) )
SMTP_USERNAME: Optional[str] = Field( SMTP_USERNAME: Optional[str] = Field(
description="smtp server username", description='smtp server username',
default=None, default=None,
) )
SMTP_PASSWORD: Optional[str] = Field( SMTP_PASSWORD: Optional[str] = Field(
description="smtp server password", description='smtp server password',
default=None, default=None,
) )
SMTP_USE_TLS: bool = Field( SMTP_USE_TLS: bool = Field(
description="whether to use TLS connection to smtp server", description='whether to use TLS connection to smtp server',
default=False, default=False,
) )
SMTP_OPPORTUNISTIC_TLS: bool = Field( SMTP_OPPORTUNISTIC_TLS: bool = Field(
description="whether to use opportunistic TLS connection to smtp server", description='whether to use opportunistic TLS connection to smtp server',
default=False, default=False,
) )
@ -471,22 +371,22 @@ class RagEtlConfig(BaseSettings):
""" """
ETL_TYPE: str = Field( ETL_TYPE: str = Field(
description="RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ", description='RAG ETL type name, default to `dify`, available values are `dify` and `Unstructured`. ',
default="dify", default='dify',
) )
KEYWORD_DATA_SOURCE_TYPE: str = Field( KEYWORD_DATA_SOURCE_TYPE: str = Field(
description="source type for keyword data, default to `database`, available values are `database` .", description='source type for keyword data, default to `database`, available values are `database` .',
default="database", default='database',
) )
UNSTRUCTURED_API_URL: Optional[str] = Field( UNSTRUCTURED_API_URL: Optional[str] = Field(
description="API URL for Unstructured", description='API URL for Unstructured',
default=None, default=None,
) )
UNSTRUCTURED_API_KEY: Optional[str] = Field( UNSTRUCTURED_API_KEY: Optional[str] = Field(
description="API key for Unstructured", description='API key for Unstructured',
default=None, default=None,
) )
@ -497,23 +397,22 @@ class DataSetConfig(BaseSettings):
""" """
CLEAN_DAY_SETTING: PositiveInt = Field( CLEAN_DAY_SETTING: PositiveInt = Field(
description="interval in days for cleaning up dataset", description='interval in days for cleaning up dataset',
default=30, default=30,
) )
DATASET_OPERATOR_ENABLED: bool = Field( DATASET_OPERATOR_ENABLED: bool = Field(
description="whether to enable dataset operator", description='whether to enable dataset operator',
default=False, default=False,
) )
class WorkspaceConfig(BaseSettings): class WorkspaceConfig(BaseSettings):
""" """
Workspace configs Workspace configs
""" """
INVITE_EXPIRY_HOURS: PositiveInt = Field( INVITE_EXPIRY_HOURS: PositiveInt = Field(
description="workspaces invitation expiration in hours", description='workspaces invitation expiration in hours',
default=72, default=72,
) )
@ -524,81 +423,25 @@ class IndexingConfig(BaseSettings):
""" """
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field( INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field(
description="max segmentation token length for indexing", description='max segmentation token length for indexing',
default=1000, default=1000,
) )
class ImageFormatConfig(BaseSettings): class ImageFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field( MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
description="multi model send image format, support base64, url, default is base64", description='multi model send image format, support base64, url, default is base64',
default="base64", default='base64',
) )
class CeleryBeatConfig(BaseSettings): class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field( CELERY_BEAT_SCHEDULER_TIME: int = Field(
description="the time of the celery scheduler, default to 1 day", description='the time of the celery scheduler, default to 1 day',
default=1, default=1,
) )
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description="The heads of model providers",
default="",
)
POSITION_PROVIDER_INCLUDES: str = Field(
description="The included model providers",
default="",
)
POSITION_PROVIDER_EXCLUDES: str = Field(
description="The excluded model providers",
default="",
)
POSITION_TOOL_PINS: str = Field(
description="The heads of tools",
default="",
)
POSITION_TOOL_INCLUDES: str = Field(
description="The included tools",
default="",
)
POSITION_TOOL_EXCLUDES: str = Field(
description="The excluded tools",
default="",
)
@computed_field
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]
@computed_field
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}
@computed_field
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}
@computed_field
def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]
@computed_field
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}
@computed_field
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order
AppExecutionConfig, AppExecutionConfig,
@ -623,7 +466,7 @@ class FeatureConfig(
UpdateConfig, UpdateConfig,
WorkflowConfig, WorkflowConfig,
WorkspaceConfig, WorkspaceConfig,
PositionConfig,
# hosted services config # hosted services config
HostedServiceConfig, HostedServiceConfig,
CeleryBeatConfig, CeleryBeatConfig,
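
One side of the `feature.py` diff defines a `PositionConfig` whose computed fields split comma-separated env strings into pin lists and include/exclude sets. A runnable miniature of that splitting logic (toy class, not the real `PositionConfig`):

```python
from pydantic import Field, computed_field
from pydantic_settings import BaseSettings

class PositionSketch(BaseSettings):
    """Toy version of the comma-splitting pattern used by PositionConfig."""

    POSITION_TOOL_PINS: str = Field(default="", description="The heads of tools")

    @computed_field
    @property
    def POSITION_TOOL_PINS_LIST(self) -> list[str]:
        # Split on commas and drop empty entries, as in the diff.
        return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]

print(PositionSketch(POSITION_TOOL_PINS="google, dalle ,").POSITION_TOOL_PINS_LIST)
# -> ['google', 'dalle']
```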


@ -10,62 +10,62 @@ class HostedOpenAiConfig(BaseSettings):
""" """
HOSTED_OPENAI_API_KEY: Optional[str] = Field( HOSTED_OPENAI_API_KEY: Optional[str] = Field(
description="", description='',
default=None, default=None,
) )
HOSTED_OPENAI_API_BASE: Optional[str] = Field( HOSTED_OPENAI_API_BASE: Optional[str] = Field(
description="", description='',
default=None, default=None,
) )
HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field( HOSTED_OPENAI_API_ORGANIZATION: Optional[str] = Field(
description="", description='',
default=None, default=None,
) )
HOSTED_OPENAI_TRIAL_ENABLED: bool = Field( HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
description="", description='',
default=False, default=False,
) )
HOSTED_OPENAI_TRIAL_MODELS: str = Field( HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="", description='',
default="gpt-3.5-turbo," default='gpt-3.5-turbo,'
"gpt-3.5-turbo-1106," 'gpt-3.5-turbo-1106,'
"gpt-3.5-turbo-instruct," 'gpt-3.5-turbo-instruct,'
"gpt-3.5-turbo-16k," 'gpt-3.5-turbo-16k,'
"gpt-3.5-turbo-16k-0613," 'gpt-3.5-turbo-16k-0613,'
"gpt-3.5-turbo-0613," 'gpt-3.5-turbo-0613,'
"gpt-3.5-turbo-0125," 'gpt-3.5-turbo-0125,'
"text-davinci-003", 'text-davinci-003',
) )
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="", description='',
default=200, default=200,
) )
HOSTED_OPENAI_PAID_ENABLED: bool = Field( HOSTED_OPENAI_PAID_ENABLED: bool = Field(
description="", description='',
default=False, default=False,
) )
HOSTED_OPENAI_PAID_MODELS: str = Field( HOSTED_OPENAI_PAID_MODELS: str = Field(
description="", description='',
default="gpt-4," default='gpt-4,'
"gpt-4-turbo-preview," 'gpt-4-turbo-preview,'
"gpt-4-turbo-2024-04-09," 'gpt-4-turbo-2024-04-09,'
"gpt-4-1106-preview," 'gpt-4-1106-preview,'
"gpt-4-0125-preview," 'gpt-4-0125-preview,'
"gpt-3.5-turbo," 'gpt-3.5-turbo,'
"gpt-3.5-turbo-16k," 'gpt-3.5-turbo-16k,'
"gpt-3.5-turbo-16k-0613," 'gpt-3.5-turbo-16k-0613,'
"gpt-3.5-turbo-1106," 'gpt-3.5-turbo-1106,'
"gpt-3.5-turbo-0613," 'gpt-3.5-turbo-0613,'
"gpt-3.5-turbo-0125," 'gpt-3.5-turbo-0125,'
"gpt-3.5-turbo-instruct," 'gpt-3.5-turbo-instruct,'
"text-davinci-003", 'text-davinci-003',
) )
@ -75,22 +75,22 @@ class HostedAzureOpenAiConfig(BaseSettings):
""" """
HOSTED_AZURE_OPENAI_ENABLED: bool = Field( HOSTED_AZURE_OPENAI_ENABLED: bool = Field(
description="", description='',
default=False, default=False,
) )
HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field( HOSTED_AZURE_OPENAI_API_KEY: Optional[str] = Field(
description="", description='',
default=None, default=None,
) )
HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field( HOSTED_AZURE_OPENAI_API_BASE: Optional[str] = Field(
description="", description='',
default=None, default=None,
) )
HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field( HOSTED_AZURE_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="", description='',
default=200, default=200,
) )
@ -101,27 +101,27 @@ class HostedAnthropicConfig(BaseSettings):
""" """
HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field( HOSTED_ANTHROPIC_API_BASE: Optional[str] = Field(
description="", description='',
default=None, default=None,
) )
HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field( HOSTED_ANTHROPIC_API_KEY: Optional[str] = Field(
description="", description='',
default=None, default=None,
) )
HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field( HOSTED_ANTHROPIC_TRIAL_ENABLED: bool = Field(
description="", description='',
default=False, default=False,
) )
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field( HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="", description='',
default=600000, default=600000,
) )
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field( HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="", description='',
default=False, default=False,
) )
@ -132,7 +132,7 @@ class HostedMinmaxConfig(BaseSettings):
""" """
HOSTED_MINIMAX_ENABLED: bool = Field( HOSTED_MINIMAX_ENABLED: bool = Field(
description="", description='',
default=False, default=False,
) )
@ -143,7 +143,7 @@ class HostedSparkConfig(BaseSettings):
""" """
HOSTED_SPARK_ENABLED: bool = Field( HOSTED_SPARK_ENABLED: bool = Field(
description="", description='',
default=False, default=False,
) )
@ -154,7 +154,7 @@ class HostedZhipuAIConfig(BaseSettings):
""" """
HOSTED_ZHIPUAI_ENABLED: bool = Field( HOSTED_ZHIPUAI_ENABLED: bool = Field(
description="", description='',
default=False, default=False,
) )
@ -165,13 +165,13 @@ class HostedModerationConfig(BaseSettings):
""" """
HOSTED_MODERATION_ENABLED: bool = Field( HOSTED_MODERATION_ENABLED: bool = Field(
description="", description='',
default=False, default=False,
) )
HOSTED_MODERATION_PROVIDERS: str = Field( HOSTED_MODERATION_PROVIDERS: str = Field(
description="", description='',
default="", default='',
) )
@ -181,15 +181,15 @@ class HostedFetchAppTemplateConfig(BaseSettings):
""" """
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field( HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
description="the mode for fetching app templates," description='the mode for fetching app templates,'
" default to remote," ' default to remote,'
" available values: remote, db, builtin", ' available values: remote, db, builtin',
default="remote", default='remote',
) )
HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field( HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN: str = Field(
description="the domain for fetching remote app templates", description='the domain for fetching remote app templates',
default="https://tmpl.dify.ai", default='https://tmpl.dify.ai',
) )
@ -202,6 +202,7 @@ class HostedServiceConfig(
HostedOpenAiConfig, HostedOpenAiConfig,
HostedSparkConfig, HostedSparkConfig,
HostedZhipuAIConfig, HostedZhipuAIConfig,
# moderation # moderation
HostedModerationConfig, HostedModerationConfig,
): ):
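The HostedServiceConfig hunk above composes the per-provider settings classes through multiple inheritance. A minimal sketch, assuming pydantic-settings v2, of how classes of this shape pick their values up from environment variables; the trimmed class bodies below only echo fields already visible in this diff, everything else is illustrative.

import os

from pydantic import Field
from pydantic_settings import BaseSettings


class HostedOpenAiConfig(BaseSettings):
    HOSTED_OPENAI_TRIAL_ENABLED: bool = Field(
        description="",
        default=False,
    )


class HostedModerationConfig(BaseSettings):
    HOSTED_MODERATION_PROVIDERS: str = Field(
        description="",
        default="",
    )


class HostedServiceConfig(
    HostedOpenAiConfig,
    # moderation
    HostedModerationConfig,
):
    pass


# Values come from the process environment (or a .env file) at instantiation time.
os.environ["HOSTED_OPENAI_TRIAL_ENABLED"] = "true"
config = HostedServiceConfig()
print(config.HOSTED_OPENAI_TRIAL_ENABLED)  # True
print(config.HOSTED_MODERATION_PROVIDERS)  # ''

Because every mixin is itself a BaseSettings subclass, each provider block can also be instantiated and tested on its own.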

View File

@ -1,7 +1,7 @@
from typing import Any, Optional from typing import Any, Optional
from urllib.parse import quote_plus from urllib.parse import quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig from configs.middleware.cache.redis_config import RedisConfig
@ -9,13 +9,10 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.milvus_config import MilvusConfig from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig from configs.middleware.vdb.opensearch_config import OpenSearchConfig
@ -31,108 +28,108 @@ from configs.middleware.vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings): class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field( STORAGE_TYPE: str = Field(
description="storage type," description='storage type,'
" default to `local`," ' default to `local`,'
" available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.", ' available values are `local`, `s3`, `azure-blob`, `aliyun-oss`, `google-storage`.',
default="local", default='local',
) )
STORAGE_LOCAL_PATH: str = Field( STORAGE_LOCAL_PATH: str = Field(
description="local storage path", description='local storage path',
default="storage", default='storage',
) )
class VectorStoreConfig(BaseSettings): class VectorStoreConfig(BaseSettings):
VECTOR_STORE: Optional[str] = Field( VECTOR_STORE: Optional[str] = Field(
description="vector store type", description='vector store type',
default=None, default=None,
) )
class KeywordStoreConfig(BaseSettings): class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field( KEYWORD_STORE: str = Field(
description="keyword store type", description='keyword store type',
default="jieba", default='jieba',
) )
class DatabaseConfig: class DatabaseConfig:
DB_HOST: str = Field( DB_HOST: str = Field(
description="db host", description='db host',
default="localhost", default='localhost',
) )
DB_PORT: PositiveInt = Field( DB_PORT: PositiveInt = Field(
description="db port", description='db port',
default=5432, default=5432,
) )
DB_USERNAME: str = Field( DB_USERNAME: str = Field(
description="db username", description='db username',
default="postgres", default='postgres',
) )
DB_PASSWORD: str = Field( DB_PASSWORD: str = Field(
description="db password", description='db password',
default="", default='',
) )
DB_DATABASE: str = Field( DB_DATABASE: str = Field(
description="db database", description='db database',
default="dify", default='dify',
) )
DB_CHARSET: str = Field( DB_CHARSET: str = Field(
description="db charset", description='db charset',
default="", default='',
) )
DB_EXTRAS: str = Field( DB_EXTRAS: str = Field(
description="db extras options. Example: keepalives_idle=60&keepalives=1", description='db extras options. Example: keepalives_idle=60&keepalives=1',
default="", default='',
) )
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field( SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
description="db uri scheme", description='db uri scheme',
default="postgresql", default='postgresql',
) )
@computed_field @computed_field
@property @property
def SQLALCHEMY_DATABASE_URI(self) -> str: def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = ( db_extras = (
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}"
if self.DB_CHARSET
else self.DB_EXTRAS
).strip("&") ).strip("&")
db_extras = f"?{db_extras}" if db_extras else "" db_extras = f"?{db_extras}" if db_extras else ""
return ( return (f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://" f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}" f"{db_extras}")
f"{db_extras}"
)
SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field( SQLALCHEMY_POOL_SIZE: NonNegativeInt = Field(
description="pool size of SqlAlchemy", description='pool size of SqlAlchemy',
default=30, default=30,
) )
SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field( SQLALCHEMY_MAX_OVERFLOW: NonNegativeInt = Field(
description="max overflows for SqlAlchemy", description='max overflows for SqlAlchemy',
default=10, default=10,
) )
SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field( SQLALCHEMY_POOL_RECYCLE: NonNegativeInt = Field(
description="SqlAlchemy pool recycle", description='SqlAlchemy pool recycle',
default=3600, default=3600,
) )
SQLALCHEMY_POOL_PRE_PING: bool = Field( SQLALCHEMY_POOL_PRE_PING: bool = Field(
description="whether to enable pool pre-ping in SqlAlchemy", description='whether to enable pool pre-ping in SqlAlchemy',
default=False, default=False,
) )
SQLALCHEMY_ECHO: bool | str = Field( SQLALCHEMY_ECHO: bool | str = Field(
description="whether to enable SqlAlchemy echo", description='whether to enable SqlAlchemy echo',
default=False, default=False,
) )
@ -140,53 +137,35 @@ class DatabaseConfig:
@property @property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
return { return {
"pool_size": self.SQLALCHEMY_POOL_SIZE, 'pool_size': self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, 'max_overflow': self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, 'pool_recycle': self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, 'pool_pre_ping': self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": {"options": "-c timezone=UTC"}, 'connect_args': {'options': '-c timezone=UTC'},
} }
class CeleryConfig(DatabaseConfig): class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field( CELERY_BACKEND: str = Field(
description="Celery backend, available values are `database`, `redis`", description='Celery backend, available values are `database`, `redis`',
default="database", default='database',
) )
CELERY_BROKER_URL: Optional[str] = Field( CELERY_BROKER_URL: Optional[str] = Field(
description="CELERY_BROKER_URL", description='CELERY_BROKER_URL',
default=None, default=None,
) )
CELERY_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel mode",
default=False,
)
CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
description="Redis Sentinel master name",
default=None,
)
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Redis Sentinel socket timeout",
default=0.1,
)
@computed_field @computed_field
@property @property
def CELERY_RESULT_BACKEND(self) -> str | None: def CELERY_RESULT_BACKEND(self) -> str | None:
return ( return 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
"db+{}".format(self.SQLALCHEMY_DATABASE_URI) if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
if self.CELERY_BACKEND == "database"
else self.CELERY_BROKER_URL
)
@computed_field @computed_field
@property @property
def BROKER_USE_SSL(self) -> bool: def BROKER_USE_SSL(self) -> bool:
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False return self.CELERY_BROKER_URL.startswith('rediss://') if self.CELERY_BROKER_URL else False
class MiddlewareConfig( class MiddlewareConfig(
@ -195,16 +174,16 @@ class MiddlewareConfig(
DatabaseConfig, DatabaseConfig,
KeywordStoreConfig, KeywordStoreConfig,
RedisConfig, RedisConfig,
# configs of storage and storage providers # configs of storage and storage providers
StorageConfig, StorageConfig,
AliyunOSSStorageConfig, AliyunOSSStorageConfig,
AzureBlobStorageConfig, AzureBlobStorageConfig,
GoogleCloudStorageConfig, GoogleCloudStorageConfig,
TencentCloudCOSStorageConfig, TencentCloudCOSStorageConfig,
HuaweiCloudOBSStorageConfig,
VolcengineTOSStorageConfig,
S3StorageConfig, S3StorageConfig,
OCIStorageConfig, OCIStorageConfig,
# configs of vdb and vdb providers # configs of vdb and vdb providers
VectorStoreConfig, VectorStoreConfig,
AnalyticdbConfig, AnalyticdbConfig,
@ -220,6 +199,5 @@ class MiddlewareConfig(
TencentVectorDBConfig, TencentVectorDBConfig,
TiDBVectorConfig, TiDBVectorConfig,
WeaviateConfig, WeaviateConfig,
ElasticsearchConfig,
): ):
pass pass
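The reformatted SQLALCHEMY_DATABASE_URI property above assembles the connection string from the scheme, URL-escaped credentials, host, port, database and optional extras. A self-contained sketch of the same construction, with made-up values:

from urllib.parse import quote_plus

scheme = "postgresql"
user, password = "postgres", "p@ss/word"   # hypothetical credentials
host, port, database = "localhost", 5432, "dify"

db_extras = "keepalives_idle=60&client_encoding=utf8".strip("&")
db_extras = f"?{db_extras}" if db_extras else ""

uri = (
    f"{scheme}://"
    f"{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"
    f"{db_extras}"
)
print(uri)
# postgresql://postgres:p%40ss%2Fword@localhost:5432/dify?keepalives_idle=60&client_encoding=utf8

quote_plus keeps characters such as '@' or '/' in the credentials from breaking the URI.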

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@ -8,63 +8,32 @@ class RedisConfig(BaseSettings):
""" """
Redis configs Redis configs
""" """
REDIS_HOST: str = Field( REDIS_HOST: str = Field(
description="Redis host", description='Redis host',
default="localhost", default='localhost',
) )
REDIS_PORT: PositiveInt = Field( REDIS_PORT: PositiveInt = Field(
description="Redis port", description='Redis port',
default=6379, default=6379,
) )
REDIS_USERNAME: Optional[str] = Field( REDIS_USERNAME: Optional[str] = Field(
description="Redis username", description='Redis username',
default=None, default=None,
) )
REDIS_PASSWORD: Optional[str] = Field( REDIS_PASSWORD: Optional[str] = Field(
description="Redis password", description='Redis password',
default=None, default=None,
) )
REDIS_DB: NonNegativeInt = Field( REDIS_DB: NonNegativeInt = Field(
description="Redis database id, default to 0", description='Redis database id, default to 0',
default=0, default=0,
) )
REDIS_USE_SSL: bool = Field( REDIS_USE_SSL: bool = Field(
description="whether to use SSL for Redis connection", description='whether to use SSL for Redis connection',
default=False, default=False,
) )
REDIS_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel mode",
default=False,
)
REDIS_SENTINELS: Optional[str] = Field(
description="Redis Sentinel nodes",
default=None,
)
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
description="Redis Sentinel service name",
default=None,
)
REDIS_SENTINEL_USERNAME: Optional[str] = Field(
description="Redis Sentinel username",
default=None,
)
REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
description="Redis Sentinel password",
default=None,
)
REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Redis Sentinel socket timeout",
default=0.1,
)
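The REDIS_USE_SENTINEL, REDIS_SENTINELS and related fields added above imply the Redis connection can be resolved through Sentinel. A hedged sketch with redis-py, assuming REDIS_SENTINELS is a comma-separated host:port list and a reachable Sentinel deployment exists; that format, the master name and all credentials below are placeholders, not values from this diff.

from redis.sentinel import Sentinel

# REDIS_SENTINELS, assumed here to be a comma-separated "host:port" list.
sentinels = "sentinel-1:26379,sentinel-2:26379"
nodes = [(host, int(port)) for host, port in (item.split(":") for item in sentinels.split(","))]

# socket_timeout mirrors REDIS_SENTINEL_SOCKET_TIMEOUT; "mymaster" stands in for the service name.
sentinel = Sentinel(nodes, socket_timeout=0.1)
master = sentinel.master_for("mymaster", db=0, password="redis-password")
master.set("ping", "pong")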

View File

@ -10,36 +10,31 @@ class AliyunOSSStorageConfig(BaseSettings):
""" """
ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field( ALIYUN_OSS_BUCKET_NAME: Optional[str] = Field(
description="Aliyun OSS bucket name", description='Aliyun OSS bucket name',
default=None, default=None,
) )
ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field( ALIYUN_OSS_ACCESS_KEY: Optional[str] = Field(
description="Aliyun OSS access key", description='Aliyun OSS access key',
default=None, default=None,
) )
ALIYUN_OSS_SECRET_KEY: Optional[str] = Field( ALIYUN_OSS_SECRET_KEY: Optional[str] = Field(
description="Aliyun OSS secret key", description='Aliyun OSS secret key',
default=None, default=None,
) )
ALIYUN_OSS_ENDPOINT: Optional[str] = Field( ALIYUN_OSS_ENDPOINT: Optional[str] = Field(
description="Aliyun OSS endpoint URL", description='Aliyun OSS endpoint URL',
default=None, default=None,
) )
ALIYUN_OSS_REGION: Optional[str] = Field( ALIYUN_OSS_REGION: Optional[str] = Field(
description="Aliyun OSS region", description='Aliyun OSS region',
default=None, default=None,
) )
ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field( ALIYUN_OSS_AUTH_VERSION: Optional[str] = Field(
description="Aliyun OSS authentication version", description='Aliyun OSS authentication version',
default=None,
)
ALIYUN_OSS_PATH: Optional[str] = Field(
description="Aliyun OSS path",
default=None, default=None,
) )

View File

@ -10,36 +10,36 @@ class S3StorageConfig(BaseSettings):
""" """
S3_ENDPOINT: Optional[str] = Field( S3_ENDPOINT: Optional[str] = Field(
description="S3 storage endpoint", description='S3 storage endpoint',
default=None, default=None,
) )
S3_REGION: Optional[str] = Field( S3_REGION: Optional[str] = Field(
description="S3 storage region", description='S3 storage region',
default=None, default=None,
) )
S3_BUCKET_NAME: Optional[str] = Field( S3_BUCKET_NAME: Optional[str] = Field(
description="S3 storage bucket name", description='S3 storage bucket name',
default=None, default=None,
) )
S3_ACCESS_KEY: Optional[str] = Field( S3_ACCESS_KEY: Optional[str] = Field(
description="S3 storage access key", description='S3 storage access key',
default=None, default=None,
) )
S3_SECRET_KEY: Optional[str] = Field( S3_SECRET_KEY: Optional[str] = Field(
description="S3 storage secret key", description='S3 storage secret key',
default=None, default=None,
) )
S3_ADDRESS_STYLE: str = Field( S3_ADDRESS_STYLE: str = Field(
description="S3 storage address style", description='S3 storage address style',
default="auto", default='auto',
) )
S3_USE_AWS_MANAGED_IAM: bool = Field( S3_USE_AWS_MANAGED_IAM: bool = Field(
description="whether to use aws managed IAM for S3", description='whether to use aws managed IAM for S3',
default=False, default=False,
) )
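Settings of this shape are typically handed to an S3 client. A sketch with boto3, assuming the mapping shown in the comments; endpoint, region, keys and bucket are placeholders, and the project's actual storage adapter is not part of this hunk.

import boto3
from botocore.config import Config

s3 = boto3.client(
    "s3",
    endpoint_url="https://s3.example.com",           # S3_ENDPOINT
    region_name="us-east-1",                         # S3_REGION
    aws_access_key_id="AKIAEXAMPLE",                 # S3_ACCESS_KEY
    aws_secret_access_key="secret-example",          # S3_SECRET_KEY
    config=Config(s3={"addressing_style": "auto"}),  # S3_ADDRESS_STYLE
)
s3.put_object(Bucket="dify-uploads", Key="hello.txt", Body=b"hello")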

View File

@ -10,21 +10,21 @@ class AzureBlobStorageConfig(BaseSettings):
""" """
AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field( AZURE_BLOB_ACCOUNT_NAME: Optional[str] = Field(
description="Azure Blob account name", description='Azure Blob account name',
default=None, default=None,
) )
AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field( AZURE_BLOB_ACCOUNT_KEY: Optional[str] = Field(
description="Azure Blob account key", description='Azure Blob account key',
default=None, default=None,
) )
AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field( AZURE_BLOB_CONTAINER_NAME: Optional[str] = Field(
description="Azure Blob container name", description='Azure Blob container name',
default=None, default=None,
) )
AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field( AZURE_BLOB_ACCOUNT_URL: Optional[str] = Field(
description="Azure Blob account URL", description='Azure Blob account URL',
default=None, default=None,
) )

View File

@ -10,11 +10,11 @@ class GoogleCloudStorageConfig(BaseSettings):
""" """
GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field( GOOGLE_STORAGE_BUCKET_NAME: Optional[str] = Field(
description="Google Cloud storage bucket name", description='Google Cloud storage bucket name',
default=None, default=None,
) )
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field( GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64: Optional[str] = Field(
description="Google Cloud storage service account json base64", description='Google Cloud storage service account json base64',
default=None, default=None,
) )

View File

@ -1,29 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class HuaweiCloudOBSStorageConfig(BaseModel):
"""
Huawei Cloud OBS storage configs
"""
HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
description="Huawei Cloud OBS bucket name",
default=None,
)
HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
description="Huawei Cloud OBS Access key",
default=None,
)
HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
description="Huawei Cloud OBS Secret key",
default=None,
)
HUAWEI_OBS_SERVER: Optional[str] = Field(
description="Huawei Cloud OBS server URL",
default=None,
)

View File

@ -10,26 +10,27 @@ class OCIStorageConfig(BaseSettings):
""" """
OCI_ENDPOINT: Optional[str] = Field( OCI_ENDPOINT: Optional[str] = Field(
description="OCI storage endpoint", description='OCI storage endpoint',
default=None, default=None,
) )
OCI_REGION: Optional[str] = Field( OCI_REGION: Optional[str] = Field(
description="OCI storage region", description='OCI storage region',
default=None, default=None,
) )
OCI_BUCKET_NAME: Optional[str] = Field( OCI_BUCKET_NAME: Optional[str] = Field(
description="OCI storage bucket name", description='OCI storage bucket name',
default=None, default=None,
) )
OCI_ACCESS_KEY: Optional[str] = Field( OCI_ACCESS_KEY: Optional[str] = Field(
description="OCI storage access key", description='OCI storage access key',
default=None, default=None,
) )
OCI_SECRET_KEY: Optional[str] = Field( OCI_SECRET_KEY: Optional[str] = Field(
description="OCI storage secret key", description='OCI storage secret key',
default=None, default=None,
) )

View File

@ -10,26 +10,26 @@ class TencentCloudCOSStorageConfig(BaseSettings):
""" """
TENCENT_COS_BUCKET_NAME: Optional[str] = Field( TENCENT_COS_BUCKET_NAME: Optional[str] = Field(
description="Tencent Cloud COS bucket name", description='Tencent Cloud COS bucket name',
default=None, default=None,
) )
TENCENT_COS_REGION: Optional[str] = Field( TENCENT_COS_REGION: Optional[str] = Field(
description="Tencent Cloud COS region", description='Tencent Cloud COS region',
default=None, default=None,
) )
TENCENT_COS_SECRET_ID: Optional[str] = Field( TENCENT_COS_SECRET_ID: Optional[str] = Field(
description="Tencent Cloud COS secret id", description='Tencent Cloud COS secret id',
default=None, default=None,
) )
TENCENT_COS_SECRET_KEY: Optional[str] = Field( TENCENT_COS_SECRET_KEY: Optional[str] = Field(
description="Tencent Cloud COS secret key", description='Tencent Cloud COS secret key',
default=None, default=None,
) )
TENCENT_COS_SCHEME: Optional[str] = Field( TENCENT_COS_SCHEME: Optional[str] = Field(
description="Tencent Cloud COS scheme", description='Tencent Cloud COS scheme',
default=None, default=None,
) )

View File

@ -1,34 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class VolcengineTOSStorageConfig(BaseModel):
"""
Volcengine tos storage configs
"""
VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
description="Volcengine TOS Bucket Name",
default=None,
)
VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
description="Volcengine TOS Access Key",
default=None,
)
VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
description="Volcengine TOS Secret Key",
default=None,
)
VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
description="Volcengine TOS Endpoint URL",
default=None,
)
VOLCENGINE_TOS_REGION: Optional[str] = Field(
description="Volcengine TOS Region",
default=None,
)

View File

@ -10,28 +10,35 @@ class AnalyticdbConfig(BaseModel):
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
""" """
ANALYTICDB_KEY_ID: Optional[str] = Field( ANALYTICDB_KEY_ID : Optional[str] = Field(
default=None, description="The Access Key ID provided by Alibaba Cloud for authentication."
)
ANALYTICDB_KEY_SECRET: Optional[str] = Field(
default=None, description="The Secret Access Key corresponding to the Access Key ID for secure access."
)
ANALYTICDB_REGION_ID: Optional[str] = Field(
default=None, description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
)
ANALYTICDB_INSTANCE_ID: Optional[str] = Field(
default=None, default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456')..", description="The Access Key ID provided by Alibaba Cloud for authentication."
) )
ANALYTICDB_ACCOUNT: Optional[str] = Field( ANALYTICDB_KEY_SECRET : Optional[str] = Field(
default=None, description="The account name used to log in to the AnalyticDB instance." default=None,
description="The Secret Access Key corresponding to the Access Key ID for secure access."
) )
ANALYTICDB_PASSWORD: Optional[str] = Field( ANALYTICDB_REGION_ID : Optional[str] = Field(
default=None, description="The password associated with the AnalyticDB account for authentication." default=None,
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
) )
ANALYTICDB_NAMESPACE: Optional[str] = Field( ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
default=None, description="The namespace within AnalyticDB for schema isolation." default=None,
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
) )
ANALYTICDB_NAMESPACE_PASSWORD: Optional[str] = Field( ANALYTICDB_ACCOUNT : Optional[str] = Field(
default=None, description="The password for accessing the specified namespace within the AnalyticDB instance." default=None,
description="The account name used to log in to the AnalyticDB instance."
)
ANALYTICDB_PASSWORD : Optional[str] = Field(
default=None,
description="The password associated with the AnalyticDB account for authentication."
)
ANALYTICDB_NAMESPACE : Optional[str] = Field(
default=None,
description="The namespace within AnalyticDB for schema isolation."
)
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
default=None,
description="The password for accessing the specified namespace within the AnalyticDB instance."
) )

View File

@ -10,31 +10,31 @@ class ChromaConfig(BaseSettings):
""" """
CHROMA_HOST: Optional[str] = Field( CHROMA_HOST: Optional[str] = Field(
description="Chroma host", description='Chroma host',
default=None, default=None,
) )
CHROMA_PORT: PositiveInt = Field( CHROMA_PORT: PositiveInt = Field(
description="Chroma port", description='Chroma port',
default=8000, default=8000,
) )
CHROMA_TENANT: Optional[str] = Field( CHROMA_TENANT: Optional[str] = Field(
description="Chroma database", description='Chroma database',
default=None, default=None,
) )
CHROMA_DATABASE: Optional[str] = Field( CHROMA_DATABASE: Optional[str] = Field(
description="Chroma database", description='Chroma database',
default=None, default=None,
) )
CHROMA_AUTH_PROVIDER: Optional[str] = Field( CHROMA_AUTH_PROVIDER: Optional[str] = Field(
description="Chroma authentication provider", description='Chroma authentication provider',
default=None, default=None,
) )
CHROMA_AUTH_CREDENTIALS: Optional[str] = Field( CHROMA_AUTH_CREDENTIALS: Optional[str] = Field(
description="Chroma authentication credentials", description='Chroma authentication credentials',
default=None, default=None,
) )

View File

@ -1,30 +0,0 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class ElasticsearchConfig(BaseSettings):
"""
Elasticsearch configs
"""
ELASTICSEARCH_HOST: Optional[str] = Field(
description="Elasticsearch host",
default="127.0.0.1",
)
ELASTICSEARCH_PORT: PositiveInt = Field(
description="Elasticsearch port",
default=9200,
)
ELASTICSEARCH_USERNAME: Optional[str] = Field(
description="Elasticsearch username",
default="elastic",
)
ELASTICSEARCH_PASSWORD: Optional[str] = Field(
description="Elasticsearch password",
default="elastic",
)

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from pydantic import Field from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@ -9,27 +9,32 @@ class MilvusConfig(BaseSettings):
Milvus configs Milvus configs
""" """
MILVUS_URI: Optional[str] = Field( MILVUS_HOST: Optional[str] = Field(
description="Milvus uri", description='Milvus host',
default="http://127.0.0.1:19530",
)
MILVUS_TOKEN: Optional[str] = Field(
description="Milvus token",
default=None, default=None,
) )
MILVUS_PORT: PositiveInt = Field(
description='Milvus RestFul API port',
default=9091,
)
MILVUS_USER: Optional[str] = Field( MILVUS_USER: Optional[str] = Field(
description="Milvus user", description='Milvus user',
default=None, default=None,
) )
MILVUS_PASSWORD: Optional[str] = Field( MILVUS_PASSWORD: Optional[str] = Field(
description="Milvus password", description='Milvus password',
default=None, default=None,
) )
MILVUS_DATABASE: str = Field( MILVUS_SECURE: bool = Field(
description="Milvus database, default to `default`", description='whether to use SSL connection for Milvus',
default="default", default=False,
)
MILVUS_DATABASE: str = Field(
description='Milvus database, default to `default`',
default='default',
) )
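The Milvus settings above replace the host/port/secure trio with a single MILVUS_URI plus MILVUS_TOKEN. A hedged sketch of what consuming them with pymilvus' MilvusClient could look like; URI, token and database are placeholders, and the project's own vector-store wiring sits outside this hunk.

from pymilvus import MilvusClient

client = MilvusClient(
    uri="http://127.0.0.1:19530",  # MILVUS_URI (default shown in the hunk above)
    token="username:password",     # MILVUS_TOKEN
    db_name="default",             # MILVUS_DATABASE
)
print(client.list_collections())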

View File

@ -1,3 +1,4 @@
from pydantic import BaseModel, Field, PositiveInt from pydantic import BaseModel, Field, PositiveInt
@ -7,31 +8,31 @@ class MyScaleConfig(BaseModel):
""" """
MYSCALE_HOST: str = Field( MYSCALE_HOST: str = Field(
description="MyScale host", description='MyScale host',
default="localhost", default='localhost',
) )
MYSCALE_PORT: PositiveInt = Field( MYSCALE_PORT: PositiveInt = Field(
description="MyScale port", description='MyScale port',
default=8123, default=8123,
) )
MYSCALE_USER: str = Field( MYSCALE_USER: str = Field(
description="MyScale user", description='MyScale user',
default="default", default='default',
) )
MYSCALE_PASSWORD: str = Field( MYSCALE_PASSWORD: str = Field(
description="MyScale password", description='MyScale password',
default="", default='',
) )
MYSCALE_DATABASE: str = Field( MYSCALE_DATABASE: str = Field(
description="MyScale database name", description='MyScale database name',
default="default", default='default',
) )
MYSCALE_FTS_PARAMS: str = Field( MYSCALE_FTS_PARAMS: str = Field(
description="MyScale fts index parameters", description='MyScale fts index parameters',
default="", default='',
) )

View File

@ -10,26 +10,26 @@ class OpenSearchConfig(BaseSettings):
""" """
OPENSEARCH_HOST: Optional[str] = Field( OPENSEARCH_HOST: Optional[str] = Field(
description="OpenSearch host", description='OpenSearch host',
default=None, default=None,
) )
OPENSEARCH_PORT: PositiveInt = Field( OPENSEARCH_PORT: PositiveInt = Field(
description="OpenSearch port", description='OpenSearch port',
default=9200, default=9200,
) )
OPENSEARCH_USER: Optional[str] = Field( OPENSEARCH_USER: Optional[str] = Field(
description="OpenSearch user", description='OpenSearch user',
default=None, default=None,
) )
OPENSEARCH_PASSWORD: Optional[str] = Field( OPENSEARCH_PASSWORD: Optional[str] = Field(
description="OpenSearch password", description='OpenSearch password',
default=None, default=None,
) )
OPENSEARCH_SECURE: bool = Field( OPENSEARCH_SECURE: bool = Field(
description="whether to use SSL connection for OpenSearch", description='whether to use SSL connection for OpenSearch',
default=False, default=False,
) )

View File

@ -10,26 +10,26 @@ class OracleConfig(BaseSettings):
""" """
ORACLE_HOST: Optional[str] = Field( ORACLE_HOST: Optional[str] = Field(
description="ORACLE host", description='ORACLE host',
default=None, default=None,
) )
ORACLE_PORT: Optional[PositiveInt] = Field( ORACLE_PORT: Optional[PositiveInt] = Field(
description="ORACLE port", description='ORACLE port',
default=1521, default=1521,
) )
ORACLE_USER: Optional[str] = Field( ORACLE_USER: Optional[str] = Field(
description="ORACLE user", description='ORACLE user',
default=None, default=None,
) )
ORACLE_PASSWORD: Optional[str] = Field( ORACLE_PASSWORD: Optional[str] = Field(
description="ORACLE password", description='ORACLE password',
default=None, default=None,
) )
ORACLE_DATABASE: Optional[str] = Field( ORACLE_DATABASE: Optional[str] = Field(
description="ORACLE database", description='ORACLE database',
default=None, default=None,
) )

View File

@ -10,26 +10,26 @@ class PGVectorConfig(BaseSettings):
""" """
PGVECTOR_HOST: Optional[str] = Field( PGVECTOR_HOST: Optional[str] = Field(
description="PGVector host", description='PGVector host',
default=None, default=None,
) )
PGVECTOR_PORT: Optional[PositiveInt] = Field( PGVECTOR_PORT: Optional[PositiveInt] = Field(
description="PGVector port", description='PGVector port',
default=5433, default=5433,
) )
PGVECTOR_USER: Optional[str] = Field( PGVECTOR_USER: Optional[str] = Field(
description="PGVector user", description='PGVector user',
default=None, default=None,
) )
PGVECTOR_PASSWORD: Optional[str] = Field( PGVECTOR_PASSWORD: Optional[str] = Field(
description="PGVector password", description='PGVector password',
default=None, default=None,
) )
PGVECTOR_DATABASE: Optional[str] = Field( PGVECTOR_DATABASE: Optional[str] = Field(
description="PGVector database", description='PGVector database',
default=None, default=None,
) )

View File

@ -10,26 +10,26 @@ class PGVectoRSConfig(BaseSettings):
""" """
PGVECTO_RS_HOST: Optional[str] = Field( PGVECTO_RS_HOST: Optional[str] = Field(
description="PGVectoRS host", description='PGVectoRS host',
default=None, default=None,
) )
PGVECTO_RS_PORT: Optional[PositiveInt] = Field( PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
description="PGVectoRS port", description='PGVectoRS port',
default=5431, default=5431,
) )
PGVECTO_RS_USER: Optional[str] = Field( PGVECTO_RS_USER: Optional[str] = Field(
description="PGVectoRS user", description='PGVectoRS user',
default=None, default=None,
) )
PGVECTO_RS_PASSWORD: Optional[str] = Field( PGVECTO_RS_PASSWORD: Optional[str] = Field(
description="PGVectoRS password", description='PGVectoRS password',
default=None, default=None,
) )
PGVECTO_RS_DATABASE: Optional[str] = Field( PGVECTO_RS_DATABASE: Optional[str] = Field(
description="PGVectoRS database", description='PGVectoRS database',
default=None, default=None,
) )

View File

@ -10,26 +10,26 @@ class QdrantConfig(BaseSettings):
""" """
QDRANT_URL: Optional[str] = Field( QDRANT_URL: Optional[str] = Field(
description="Qdrant url", description='Qdrant url',
default=None, default=None,
) )
QDRANT_API_KEY: Optional[str] = Field( QDRANT_API_KEY: Optional[str] = Field(
description="Qdrant api key", description='Qdrant api key',
default=None, default=None,
) )
QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field( QDRANT_CLIENT_TIMEOUT: NonNegativeInt = Field(
description="Qdrant client timeout in seconds", description='Qdrant client timeout in seconds',
default=20, default=20,
) )
QDRANT_GRPC_ENABLED: bool = Field( QDRANT_GRPC_ENABLED: bool = Field(
description="whether enable grpc support for Qdrant connection", description='whether enable grpc support for Qdrant connection',
default=False, default=False,
) )
QDRANT_GRPC_PORT: PositiveInt = Field( QDRANT_GRPC_PORT: PositiveInt = Field(
description="Qdrant grpc port", description='Qdrant grpc port',
default=6334, default=6334,
) )
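For comparison, a hedged sketch of how the Qdrant fields above could map onto qdrant-client constructor arguments; the URL and API key are placeholders, and this is illustrative rather than the project's actual client factory.

from qdrant_client import QdrantClient

client = QdrantClient(
    url="https://qdrant.example.com",  # QDRANT_URL
    api_key="qdrant-api-key",          # QDRANT_API_KEY
    timeout=20,                        # QDRANT_CLIENT_TIMEOUT
    prefer_grpc=False,                 # QDRANT_GRPC_ENABLED
    grpc_port=6334,                    # QDRANT_GRPC_PORT
)
print(client.get_collections())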

View File

@ -10,26 +10,26 @@ class RelytConfig(BaseSettings):
""" """
RELYT_HOST: Optional[str] = Field( RELYT_HOST: Optional[str] = Field(
description="Relyt host", description='Relyt host',
default=None, default=None,
) )
RELYT_PORT: PositiveInt = Field( RELYT_PORT: PositiveInt = Field(
description="Relyt port", description='Relyt port',
default=9200, default=9200,
) )
RELYT_USER: Optional[str] = Field( RELYT_USER: Optional[str] = Field(
description="Relyt user", description='Relyt user',
default=None, default=None,
) )
RELYT_PASSWORD: Optional[str] = Field( RELYT_PASSWORD: Optional[str] = Field(
description="Relyt password", description='Relyt password',
default=None, default=None,
) )
RELYT_DATABASE: Optional[str] = Field( RELYT_DATABASE: Optional[str] = Field(
description="Relyt database", description='Relyt database',
default="default", default='default',
) )

View File

@ -10,41 +10,41 @@ class TencentVectorDBConfig(BaseSettings):
""" """
TENCENT_VECTOR_DB_URL: Optional[str] = Field( TENCENT_VECTOR_DB_URL: Optional[str] = Field(
description="Tencent Vector URL", description='Tencent Vector URL',
default=None, default=None,
) )
TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field( TENCENT_VECTOR_DB_API_KEY: Optional[str] = Field(
description="Tencent Vector API key", description='Tencent Vector API key',
default=None, default=None,
) )
TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field( TENCENT_VECTOR_DB_TIMEOUT: PositiveInt = Field(
description="Tencent Vector timeout in seconds", description='Tencent Vector timeout in seconds',
default=30, default=30,
) )
TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field( TENCENT_VECTOR_DB_USERNAME: Optional[str] = Field(
description="Tencent Vector username", description='Tencent Vector username',
default=None, default=None,
) )
TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field( TENCENT_VECTOR_DB_PASSWORD: Optional[str] = Field(
description="Tencent Vector password", description='Tencent Vector password',
default=None, default=None,
) )
TENCENT_VECTOR_DB_SHARD: PositiveInt = Field( TENCENT_VECTOR_DB_SHARD: PositiveInt = Field(
description="Tencent Vector sharding number", description='Tencent Vector sharding number',
default=1, default=1,
) )
TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field( TENCENT_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
description="Tencent Vector replicas", description='Tencent Vector replicas',
default=2, default=2,
) )
TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field( TENCENT_VECTOR_DB_DATABASE: Optional[str] = Field(
description="Tencent Vector Database", description='Tencent Vector Database',
default=None, default=None,
) )

View File

@ -10,26 +10,26 @@ class TiDBVectorConfig(BaseSettings):
""" """
TIDB_VECTOR_HOST: Optional[str] = Field( TIDB_VECTOR_HOST: Optional[str] = Field(
description="TiDB Vector host", description='TiDB Vector host',
default=None, default=None,
) )
TIDB_VECTOR_PORT: Optional[PositiveInt] = Field( TIDB_VECTOR_PORT: Optional[PositiveInt] = Field(
description="TiDB Vector port", description='TiDB Vector port',
default=4000, default=4000,
) )
TIDB_VECTOR_USER: Optional[str] = Field( TIDB_VECTOR_USER: Optional[str] = Field(
description="TiDB Vector user", description='TiDB Vector user',
default=None, default=None,
) )
TIDB_VECTOR_PASSWORD: Optional[str] = Field( TIDB_VECTOR_PASSWORD: Optional[str] = Field(
description="TiDB Vector password", description='TiDB Vector password',
default=None, default=None,
) )
TIDB_VECTOR_DATABASE: Optional[str] = Field( TIDB_VECTOR_DATABASE: Optional[str] = Field(
description="TiDB Vector database", description='TiDB Vector database',
default=None, default=None,
) )

View File

@ -10,21 +10,21 @@ class WeaviateConfig(BaseSettings):
""" """
WEAVIATE_ENDPOINT: Optional[str] = Field( WEAVIATE_ENDPOINT: Optional[str] = Field(
description="Weaviate endpoint URL", description='Weaviate endpoint URL',
default=None, default=None,
) )
WEAVIATE_API_KEY: Optional[str] = Field( WEAVIATE_API_KEY: Optional[str] = Field(
description="Weaviate API key", description='Weaviate API key',
default=None, default=None,
) )
WEAVIATE_GRPC_ENABLED: bool = Field( WEAVIATE_GRPC_ENABLED: bool = Field(
description="whether to enable gRPC for Weaviate connection", description='whether to enable gRPC for Weaviate connection',
default=True, default=True,
) )
WEAVIATE_BATCH_SIZE: PositiveInt = Field( WEAVIATE_BATCH_SIZE: PositiveInt = Field(
description="Weaviate batch size", description='Weaviate batch size',
default=100, default=100,
) )

View File

@ -8,11 +8,11 @@ class PackagingInfo(BaseSettings):
""" """
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description='Dify version',
default="0.8.1", default='0.7.0',
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(
description="SHA-1 checksum of the git commit used to build the app", description="SHA-1 checksum of the git commit used to build the app",
default="", default='',
) )

File diff suppressed because one or more lines are too long

View File

@ -1 +1,3 @@

View File

@ -2,7 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
bp = Blueprint("console", __name__, url_prefix="/console/api") bp = Blueprint('console', __name__, url_prefix='/console/api')
api = ExternalApi(bp) api = ExternalApi(bp)
# Import other controllers # Import other controllers

View File

@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp
def admin_required(view): def admin_required(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
if not os.getenv("ADMIN_API_KEY"): if not os.getenv('ADMIN_API_KEY'):
raise Unauthorized("API key is invalid.") raise Unauthorized('API key is invalid.')
auth_header = request.headers.get("Authorization") auth_header = request.headers.get('Authorization')
if auth_header is None: if auth_header is None:
raise Unauthorized("Authorization header is missing.") raise Unauthorized('Authorization header is missing.')
if " " not in auth_header: if ' ' not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower() auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer": if auth_scheme != 'bearer':
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
if os.getenv("ADMIN_API_KEY") != auth_token: if os.getenv('ADMIN_API_KEY') != auth_token:
raise Unauthorized("API key is invalid.") raise Unauthorized('API key is invalid.')
return view(*args, **kwargs) return view(*args, **kwargs)
@ -44,33 +44,37 @@ class InsertExploreAppListApi(Resource):
@admin_required @admin_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") parser.add_argument('app_id', type=str, required=True, nullable=False, location='json')
parser.add_argument("desc", type=str, location="json") parser.add_argument('desc', type=str, location='json')
parser.add_argument("copyright", type=str, location="json") parser.add_argument('copyright', type=str, location='json')
parser.add_argument("privacy_policy", type=str, location="json") parser.add_argument('privacy_policy', type=str, location='json')
parser.add_argument("custom_disclaimer", type=str, location="json") parser.add_argument('custom_disclaimer', type=str, location='json')
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json')
parser.add_argument("category", type=str, required=True, nullable=False, location="json") parser.add_argument('category', type=str, required=True, nullable=False, location='json')
parser.add_argument("position", type=int, required=True, nullable=False, location="json") parser.add_argument('position', type=int, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
app = App.query.filter(App.id == args["app_id"]).first() app = App.query.filter(App.id == args['app_id']).first()
if not app: if not app:
raise NotFound(f'App \'{args["app_id"]}\' is not found') raise NotFound(f'App \'{args["app_id"]}\' is not found')
site = app.site site = app.site
if not site: if not site:
desc = args["desc"] or "" desc = args['desc'] if args['desc'] else ''
copy_right = args["copyright"] or "" copy_right = args['copyright'] if args['copyright'] else ''
privacy_policy = args["privacy_policy"] or "" privacy_policy = args['privacy_policy'] if args['privacy_policy'] else ''
custom_disclaimer = args["custom_disclaimer"] or "" custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else ''
else: else:
desc = site.description or args["desc"] or "" desc = site.description if site.description else \
copy_right = site.copyright or args["copyright"] or "" args['desc'] if args['desc'] else ''
privacy_policy = site.privacy_policy or args["privacy_policy"] or "" copy_right = site.copyright if site.copyright else \
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" args['copyright'] if args['copyright'] else ''
privacy_policy = site.privacy_policy if site.privacy_policy else \
args['privacy_policy'] if args['privacy_policy'] else ''
custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \
args['custom_disclaimer'] if args['custom_disclaimer'] else ''
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
if not recommended_app: if not recommended_app:
recommended_app = RecommendedApp( recommended_app = RecommendedApp(
@ -79,9 +83,9 @@ class InsertExploreAppListApi(Resource):
copyright=copy_right, copyright=copy_right,
privacy_policy=privacy_policy, privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer, custom_disclaimer=custom_disclaimer,
language=args["language"], language=args['language'],
category=args["category"], category=args['category'],
position=args["position"], position=args['position']
) )
db.session.add(recommended_app) db.session.add(recommended_app)
@ -89,21 +93,21 @@ class InsertExploreAppListApi(Resource):
app.is_public = True app.is_public = True
db.session.commit() db.session.commit()
return {"result": "success"}, 201 return {'result': 'success'}, 201
else: else:
recommended_app.description = desc recommended_app.description = desc
recommended_app.copyright = copy_right recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"] recommended_app.language = args['language']
recommended_app.category = args["category"] recommended_app.category = args['category']
recommended_app.position = args["position"] recommended_app.position = args['position']
app.is_public = True app.is_public = True
db.session.commit() db.session.commit()
return {"result": "success"}, 200 return {'result': 'success'}, 200
class InsertExploreAppApi(Resource): class InsertExploreAppApi(Resource):
@ -112,14 +116,15 @@ class InsertExploreAppApi(Resource):
def delete(self, app_id): def delete(self, app_id):
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
if not recommended_app: if not recommended_app:
return {"result": "success"}, 204 return {'result': 'success'}, 204
app = App.query.filter(App.id == recommended_app.app_id).first() app = App.query.filter(App.id == recommended_app.app_id).first()
if app: if app:
app.is_public = False app.is_public = False
installed_apps = InstalledApp.query.filter( installed_apps = InstalledApp.query.filter(
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
).all() ).all()
for installed_app in installed_apps: for installed_app in installed_apps:
@ -128,8 +133,8 @@ class InsertExploreAppApi(Resource):
db.session.delete(recommended_app) db.session.delete(recommended_app)
db.session.commit() db.session.commit()
return {"result": "success"}, 204 return {'result': 'success'}, 204
api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps')
api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>") api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/<uuid:app_id>')
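The refactor in this file collapses nested conditional expressions into `or` chains. A small worked example of why the two spellings agree here: None and the empty string are both falsy, so the first truthy candidate wins and the final literal supplies the fallback.

site_description = None   # e.g. site.description
arg_desc = ""              # e.g. args["desc"]

# Old spelling: nested conditional expressions.
desc_old = site_description if site_description else (arg_desc if arg_desc else "")

# New spelling: an `or` chain that short-circuits on the first truthy value.
desc_new = site_description or arg_desc or ""

assert desc_old == desc_new == ""
# Caveat: `or` treats every falsy value (0, [], "") as "missing", which is fine for these string fields.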

View File

@ -14,21 +14,26 @@ from .setup import setup_required
from .wraps import account_initialization_required from .wraps import account_initialization_required
api_key_fields = { api_key_fields = {
"id": fields.String, 'id': fields.String,
"type": fields.String, 'type': fields.String,
"token": fields.String, 'token': fields.String,
"last_used_at": TimestampField, 'last_used_at': TimestampField,
"created_at": TimestampField, 'created_at': TimestampField
} }
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} api_key_list = {
'data': fields.List(fields.Nested(api_key_fields), attribute="items")
}
def _get_resource(resource_id, tenant_id, resource_model): def _get_resource(resource_id, tenant_id, resource_model):
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() resource = resource_model.query.filter_by(
id=resource_id, tenant_id=tenant_id
).first()
if resource is None: if resource is None:
flask_restful.abort(404, message=f"{resource_model.__name__} not found.") flask_restful.abort(
404, message=f"{resource_model.__name__} not found.")
return resource return resource
@ -45,32 +50,30 @@ class BaseApiKeyListResource(Resource):
@marshal_with(api_key_list) @marshal_with(api_key_list)
def get(self, resource_id): def get(self, resource_id):
resource_id = str(resource_id) resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id,
self.resource_model)
keys = ( keys = db.session.query(ApiToken). \
db.session.query(ApiToken) filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) all()
.all()
)
return {"items": keys} return {"items": keys}
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
def post(self, resource_id): def post(self, resource_id):
resource_id = str(resource_id) resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id,
self.resource_model)
if not current_user.is_editor: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
current_key_count = ( current_key_count = db.session.query(ApiToken). \
db.session.query(ApiToken) filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) count()
.count()
)
if current_key_count >= self.max_keys: if current_key_count >= self.max_keys:
flask_restful.abort( flask_restful.abort(
400, 400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.", message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded", code='max_keys_exceeded'
) )
key = ApiToken.generate_api_key(self.token_prefix, 24) key = ApiToken.generate_api_key(self.token_prefix, 24)
@ -94,78 +97,79 @@ class BaseApiKeyResource(Resource):
def delete(self, resource_id, api_key_id): def delete(self, resource_id, api_key_id):
resource_id = str(resource_id) resource_id = str(resource_id)
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id,
self.resource_model)
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
key = ( key = db.session.query(ApiToken). \
db.session.query(ApiToken) filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \
.filter( first()
getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
)
if key is None: if key is None:
flask_restful.abort(404, message="API key not found") flask_restful.abort(404, message='API key not found')
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
return {"result": "success"}, 204 return {'result': 'success'}, 204
class AppApiKeyListResource(BaseApiKeyListResource): class AppApiKeyListResource(BaseApiKeyListResource):
def after_request(self, resp): def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers["Access-Control-Allow-Credentials"] = "true" resp.headers['Access-Control-Allow-Credentials'] = 'true'
return resp return resp
resource_type = "app" resource_type = 'app'
resource_model = App resource_model = App
resource_id_field = "app_id" resource_id_field = 'app_id'
token_prefix = "app-" token_prefix = 'app-'
class AppApiKeyResource(BaseApiKeyResource): class AppApiKeyResource(BaseApiKeyResource):
def after_request(self, resp): def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers["Access-Control-Allow-Credentials"] = "true" resp.headers['Access-Control-Allow-Credentials'] = 'true'
return resp return resp
resource_type = "app" resource_type = 'app'
resource_model = App resource_model = App
resource_id_field = "app_id" resource_id_field = 'app_id'
class DatasetApiKeyListResource(BaseApiKeyListResource): class DatasetApiKeyListResource(BaseApiKeyListResource):
def after_request(self, resp): def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers["Access-Control-Allow-Credentials"] = "true" resp.headers['Access-Control-Allow-Credentials'] = 'true'
return resp return resp
resource_type = "dataset" resource_type = 'dataset'
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = 'dataset_id'
token_prefix = "ds-" token_prefix = 'ds-'
class DatasetApiKeyResource(BaseApiKeyResource): class DatasetApiKeyResource(BaseApiKeyResource):
def after_request(self, resp): def after_request(self, resp):
resp.headers["Access-Control-Allow-Origin"] = "*" resp.headers['Access-Control-Allow-Origin'] = '*'
resp.headers["Access-Control-Allow-Credentials"] = "true" resp.headers['Access-Control-Allow-Credentials'] = 'true'
return resp return resp
resource_type = 'dataset'
resource_type = "dataset"
resource_model = Dataset resource_model = Dataset
resource_id_field = "dataset_id" resource_id_field = 'dataset_id'
api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys") api.add_resource(AppApiKeyListResource, '/apps/<uuid:resource_id>/api-keys')
api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>") api.add_resource(AppApiKeyResource,
'/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys") api.add_resource(DatasetApiKeyListResource,
'/datasets/<uuid:resource_id>/api-keys')
api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>") api.add_resource(DatasetApiKeyResource,
'/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>')

View File

@ -8,18 +8,19 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ
class AdvancedPromptTemplateList(Resource): class AdvancedPromptTemplateList(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("app_mode", type=str, required=True, location="args") parser.add_argument('app_mode', type=str, required=True, location='args')
parser.add_argument("model_mode", type=str, required=True, location="args") parser.add_argument('model_mode', type=str, required=True, location='args')
parser.add_argument("has_context", type=str, required=False, default="true", location="args") parser.add_argument('has_context', type=str, required=False, default='true', location='args')
parser.add_argument("model_name", type=str, required=True, location="args") parser.add_argument('model_name', type=str, required=True, location='args')
args = parser.parse_args() args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args) return AdvancedPromptTemplateService.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates")

View File

@ -18,12 +18,15 @@ class AgentLogApi(Resource):
def get(self, app_model): def get(self, app_model):
"""Get agent logs""" """Get agent logs"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("message_id", type=uuid_value, required=True, location="args") parser.add_argument('message_id', type=uuid_value, required=True, location='args')
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") parser.add_argument('conversation_id', type=uuid_value, required=True, location='args')
args = parser.parse_args() args = parser.parse_args()
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) return AgentService.get_agent_logs(
app_model,
args['conversation_id'],
args['message_id']
)
api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs") api.add_resource(AgentLogApi, '/apps/<uuid:app_id>/agent/logs')
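Most handlers touched in this changeset follow the same reqparse pattern shown above. A minimal standalone sketch of that pattern, with an illustrative route and plain string types instead of the project's uuid_value validator:

from flask import Flask
from flask_restful import Api, Resource, reqparse

app = Flask(__name__)
api = Api(app)


class AgentLogDemo(Resource):
    def get(self):
        parser = reqparse.RequestParser()
        # Required query-string arguments, e.g. ?message_id=...&conversation_id=...
        parser.add_argument("message_id", type=str, required=True, location="args")
        parser.add_argument("conversation_id", type=str, required=True, location="args")
        args = parser.parse_args()
        return {"message_id": args["message_id"], "conversation_id": args["conversation_id"]}


api.add_resource(AgentLogDemo, "/demo/agent/logs")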

View File

@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check('annotation')
def post(self, app_id, action): def post(self, app_id, action):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json") parser.add_argument('score_threshold', required=True, type=float, location='json')
parser.add_argument("embedding_provider_name", required=True, type=str, location="json") parser.add_argument('embedding_provider_name', required=True, type=str, location='json')
parser.add_argument("embedding_model_name", required=True, type=str, location="json") parser.add_argument('embedding_model_name', required=True, type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
if action == "enable": if action == 'enable':
result = AppAnnotationService.enable_app_annotation(args, app_id) result = AppAnnotationService.enable_app_annotation(args, app_id)
elif action == "disable": elif action == 'disable':
result = AppAnnotationService.disable_app_annotation(app_id) result = AppAnnotationService.disable_app_annotation(app_id)
else: else:
raise ValueError("Unsupported annotation reply action") raise ValueError('Unsupported annotation reply action')
return result, 200 return result, 200
@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource):
annotation_setting_id = str(annotation_setting_id) annotation_setting_id = str(annotation_setting_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json") parser.add_argument('score_threshold', required=True, type=float, location='json')
args = parser.parse_args() args = parser.parse_args()
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
@ -77,24 +77,28 @@ class AnnotationReplyActionStatusApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id, action): def get(self, app_id, job_id, action):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key) cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job is not exist.")
job_status = cache_result.decode() job_status = cache_result.decode()
error_msg = "" error_msg = ''
if job_status == "error": if job_status == 'error':
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode() error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
class AnnotationListApi(Resource): class AnnotationListApi(Resource):
@ -105,18 +109,18 @@ class AnnotationListApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
page = request.args.get("page", default=1, type=int) page = request.args.get('page', default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get('limit', default=20, type=int)
keyword = request.args.get("keyword", default=None, type=str) keyword = request.args.get('keyword', default=None, type=str)
app_id = str(app_id) app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
response = { response = {
"data": marshal(annotation_list, annotation_fields), 'data': marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit, 'has_more': len(annotation_list) == limit,
"limit": limit, 'limit': limit,
"total": total, 'total': total,
"page": page, 'page': page
} }
return response, 200 return response, 200
@ -131,7 +135,9 @@ class AnnotationExportApi(Resource):
app_id = str(app_id) app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
response = {"data": marshal(annotation_list, annotation_fields)} response = {
'data': marshal(annotation_list, annotation_fields)
}
return response, 200 return response, 200
@ -139,7 +145,7 @@ class AnnotationCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id): def post(self, app_id):
if not current_user.is_editor: if not current_user.is_editor:
@ -147,8 +153,8 @@ class AnnotationCreateApi(Resource):
app_id = str(app_id) app_id = str(app_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json") parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument("answer", required=True, type=str, location="json") parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
return annotation return annotation
@ -158,7 +164,7 @@ class AnnotationUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check('annotation')
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_id, annotation_id): def post(self, app_id, annotation_id):
if not current_user.is_editor: if not current_user.is_editor:
@ -167,8 +173,8 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json") parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument("answer", required=True, type=str, location="json") parser.add_argument('answer', required=True, type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
return annotation return annotation
@ -183,29 +189,29 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id)
return {"result": "success"}, 200 return {'result': 'success'}, 200
class AnnotationBatchImportApi(Resource): class AnnotationBatchImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check('annotation')
def post(self, app_id): def post(self, app_id):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
app_id = str(app_id) app_id = str(app_id)
# get file from request # get file from request
file = request.files["file"] file = request.files['file']
# check file # check file
if "file" not in request.files: if 'file' not in request.files:
raise NoFileUploadedError() raise NoFileUploadedError()
if len(request.files) > 1: if len(request.files) > 1:
raise TooManyFilesError() raise TooManyFilesError()
# check file type # check file type
if not file.filename.endswith(".csv"): if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed") raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file) return AppAnnotationService.batch_import_app_annotations(app_id, file)
@ -214,23 +220,27 @@ class AnnotationBatchImportStatusApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check('annotation')
def get(self, app_id, job_id): def get(self, app_id, job_id):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
job_id = str(job_id) job_id = str(job_id)
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job is not exist.")
job_status = cache_result.decode() job_status = cache_result.decode()
error_msg = "" error_msg = ''
if job_status == "error": if job_status == 'error':
indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id))
error_msg = redis_client.get(indexing_error_msg_key).decode() error_msg = redis_client.get(indexing_error_msg_key).decode()
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 return {
'job_id': job_id,
'job_status': job_status,
'error_msg': error_msg
}, 200
class AnnotationHitHistoryListApi(Resource): class AnnotationHitHistoryListApi(Resource):
@ -241,32 +251,30 @@ class AnnotationHitHistoryListApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
page = request.args.get("page", default=1, type=int) page = request.args.get('page', default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get('limit', default=20, type=int)
app_id = str(app_id) app_id = str(app_id)
annotation_id = str(annotation_id) annotation_id = str(annotation_id)
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id,
app_id, annotation_id, page, limit page, limit)
)
response = { response = {
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields), 'data': marshal(annotation_hit_history_list, annotation_hit_history_fields),
"has_more": len(annotation_hit_history_list) == limit, 'has_more': len(annotation_hit_history_list) == limit,
"limit": limit, 'limit': limit,
"total": total, 'total': total,
"page": page, 'page': page
} }
return response return response
api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>") api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>')
api.add_resource( api.add_resource(AnnotationReplyActionStatusApi,
AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>" '/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>')
) api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations") api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export')
api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export") api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>')
api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>") api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import')
api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import") api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>')
api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>") api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories')
api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories") api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting')
api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting") api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>')
api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
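
Several endpoints in this file poll background jobs through Redis keys of the form "<action>_app_annotation_job_<job_id>", plus a matching error key when the status is "error". A standalone sketch of that convention; the helper name is invented, and redis_client is assumed to be a configured redis.Redis instance.

```python
# Standalone sketch of the Redis job-status convention used above.
def read_annotation_job_status(redis_client, action: str, job_id: str) -> dict:
    job_key = "{}_app_annotation_job_{}".format(action, job_id)
    cache_result = redis_client.get(job_key)
    if cache_result is None:
        raise ValueError("The job does not exist.")

    job_status = cache_result.decode()
    error_msg = ""
    if job_status == "error":
        error_key = "{}_app_annotation_error_{}".format(action, job_id)
        cached_error = redis_client.get(error_key)
        error_msg = cached_error.decode() if cached_error else ""

    return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}
```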

View File

@ -18,35 +18,27 @@ from libs.login import login_required
from services.app_dsl_service import AppDslService from services.app_dsl_service import AppDslService
from services.app_service import AppService from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
class AppListApi(Resource): class AppListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
"""Get app list""" """Get app list"""
def uuid_list(value): def uuid_list(value):
try: try:
return [str(uuid.UUID(v)) for v in value.split(",")] return [str(uuid.UUID(v)) for v in value.split(',')]
except ValueError: except ValueError:
abort(400, message="Invalid UUID format in tag_ids.") abort(400, message="Invalid UUID format in tag_ids.")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
parser.add_argument( parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
"mode", parser.add_argument('name', type=str, location='args', required=False)
type=str, parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
choices=["chat", "workflow", "agent-chat", "channel", "all"],
default="all",
location="args",
required=False,
)
parser.add_argument("name", type=str, location="args", required=False)
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
args = parser.parse_args() args = parser.parse_args()
@ -54,7 +46,7 @@ class AppListApi(Resource):
app_service = AppService() app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
if not app_pagination: if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
return marshal(app_pagination, app_pagination_fields) return marshal(app_pagination, app_pagination_fields)
@ -62,23 +54,23 @@ class AppListApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_detail_fields) @marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check('apps')
def post(self): def post(self):
"""Create app""" """Create app"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument("description", type=str, location="json") parser.add_argument('description', type=str, location='json')
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json')
parser.add_argument("icon_type", type=str, location="json") parser.add_argument('icon_type', type=str, location='json')
parser.add_argument("icon", type=str, location="json") parser.add_argument('icon', type=str, location='json')
parser.add_argument("icon_background", type=str, location="json") parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if "mode" not in args or args["mode"] is None: if 'mode' not in args or args['mode'] is None:
raise BadRequest("mode is required") raise BadRequest("mode is required")
app_service = AppService() app_service = AppService()
@ -92,7 +84,7 @@ class AppImportApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_detail_fields_with_site) @marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check('apps')
def post(self): def post(self):
"""Import app""" """Import app"""
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
@ -100,16 +92,19 @@ class AppImportApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json") parser.add_argument('data', type=str, required=True, nullable=False, location='json')
parser.add_argument("name", type=str, location="json") parser.add_argument('name', type=str, location='json')
parser.add_argument("description", type=str, location="json") parser.add_argument('description', type=str, location='json')
parser.add_argument("icon_type", type=str, location="json") parser.add_argument('icon_type', type=str, location='json')
parser.add_argument("icon", type=str, location="json") parser.add_argument('icon', type=str, location='json')
parser.add_argument("icon_background", type=str, location="json") parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
app = AppDslService.import_and_create_new_app( app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user tenant_id=current_user.current_tenant_id,
data=args['data'],
args=args,
account=current_user
) )
return app, 201 return app, 201
@ -120,7 +115,7 @@ class AppImportFromUrlApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_detail_fields_with_site) @marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check('apps')
def post(self): def post(self):
"""Import app from url""" """Import app from url"""
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
@ -128,21 +123,25 @@ class AppImportFromUrlApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="json") parser.add_argument('url', type=str, required=True, nullable=False, location='json')
parser.add_argument("name", type=str, location="json") parser.add_argument('name', type=str, location='json')
parser.add_argument("description", type=str, location="json") parser.add_argument('description', type=str, location='json')
parser.add_argument("icon", type=str, location="json") parser.add_argument('icon', type=str, location='json')
parser.add_argument("icon_background", type=str, location="json") parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
app = AppDslService.import_and_create_new_app_from_url( app = AppDslService.import_and_create_new_app_from_url(
tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user tenant_id=current_user.current_tenant_id,
url=args['url'],
args=args,
account=current_user
) )
return app, 201 return app, 201
class AppApi(Resource): class AppApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -166,15 +165,14 @@ class AppApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument('name', type=str, required=True, nullable=False, location='json')
parser.add_argument("description", type=str, location="json") parser.add_argument('description', type=str, location='json')
parser.add_argument("icon_type", type=str, location="json") parser.add_argument('icon_type', type=str, location='json')
parser.add_argument("icon", type=str, location="json") parser.add_argument('icon', type=str, location='json')
parser.add_argument("icon_background", type=str, location="json") parser.add_argument('icon_background', type=str, location='json')
parser.add_argument("max_active_requests", type=int, location="json") parser.add_argument('max_active_requests', type=int, location='json')
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
@ -195,7 +193,7 @@ class AppApi(Resource):
app_service = AppService() app_service = AppService()
app_service.delete_app(app_model) app_service.delete_app(app_model)
return {"result": "success"}, 204 return {'result': 'success'}, 204
class AppCopyApi(Resource): class AppCopyApi(Resource):
@ -211,16 +209,19 @@ class AppCopyApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, location="json") parser.add_argument('name', type=str, location='json')
parser.add_argument("description", type=str, location="json") parser.add_argument('description', type=str, location='json')
parser.add_argument("icon_type", type=str, location="json") parser.add_argument('icon_type', type=str, location='json')
parser.add_argument("icon", type=str, location="json") parser.add_argument('icon', type=str, location='json')
parser.add_argument("icon_background", type=str, location="json") parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
data = AppDslService.export_dsl(app_model=app_model, include_secret=True) data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
app = AppDslService.import_and_create_new_app( app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user tenant_id=current_user.current_tenant_id,
data=data,
args=args,
account=current_user
) )
return app, 201 return app, 201
@ -239,10 +240,12 @@ class AppExportApi(Resource):
# Add include_secret params # Add include_secret params
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args')
args = parser.parse_args() args = parser.parse_args()
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])} return {
"data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret'])
}
class AppNameApi(Resource): class AppNameApi(Resource):
@ -255,13 +258,13 @@ class AppNameApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument('name', type=str, required=True, location='json')
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_name(app_model, args.get("name")) app_model = app_service.update_app_name(app_model, args.get('name'))
return app_model return app_model
@ -276,14 +279,14 @@ class AppIconApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("icon", type=str, location="json") parser.add_argument('icon', type=str, location='json')
parser.add_argument("icon_background", type=str, location="json") parser.add_argument('icon_background', type=str, location='json')
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background'))
return app_model return app_model
@ -298,13 +301,13 @@ class AppSiteStatus(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("enable_site", type=bool, required=True, location="json") parser.add_argument('enable_site', type=bool, required=True, location='json')
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) app_model = app_service.update_app_site_status(app_model, args.get('enable_site'))
return app_model return app_model
@ -319,13 +322,13 @@ class AppApiStatus(Resource):
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("enable_api", type=bool, required=True, location="json") parser.add_argument('enable_api', type=bool, required=True, location='json')
args = parser.parse_args() args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) app_model = app_service.update_app_api_status(app_model, args.get('enable_api'))
return app_model return app_model
@ -336,7 +339,9 @@ class AppTraceApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
"""Get app trace""" """Get app trace"""
app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id) app_trace_config = OpsTraceManager.get_app_tracing_config(
app_id=app_id
)
return app_trace_config return app_trace_config
@ -348,27 +353,27 @@ class AppTraceApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("enabled", type=bool, required=True, location="json") parser.add_argument('enabled', type=bool, required=True, location='json')
parser.add_argument("tracing_provider", type=str, required=True, location="json") parser.add_argument('tracing_provider', type=str, required=True, location='json')
args = parser.parse_args() args = parser.parse_args()
OpsTraceManager.update_app_tracing_config( OpsTraceManager.update_app_tracing_config(
app_id=app_id, app_id=app_id,
enabled=args["enabled"], enabled=args['enabled'],
tracing_provider=args["tracing_provider"], tracing_provider=args['tracing_provider'],
) )
return {"result": "success"} return {"result": "success"}
api.add_resource(AppListApi, "/apps") api.add_resource(AppListApi, '/apps')
api.add_resource(AppImportApi, "/apps/import") api.add_resource(AppImportApi, '/apps/import')
api.add_resource(AppImportFromUrlApi, "/apps/import/url") api.add_resource(AppImportFromUrlApi, '/apps/import/url')
api.add_resource(AppApi, "/apps/<uuid:app_id>") api.add_resource(AppApi, '/apps/<uuid:app_id>')
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy") api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy')
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export") api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export')
api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name") api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon") api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon')
api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable") api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable") api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace") api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace')
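
AppListApi above accepts tag_ids as a comma-separated UUID list via a custom reqparse type. A self-contained sketch of that helper; the sample UUID is arbitrary.

```python
# Sketch of the custom uuid_list argument type used by AppListApi: any callable
# can be passed as type=...; it returns the parsed value or aborts with a 400.
import uuid

from flask_restful import abort, reqparse


def uuid_list(value):
    try:
        return [str(uuid.UUID(v)) for v in value.split(",")]
    except ValueError:
        abort(400, message="Invalid UUID format in tag_ids.")


parser = reqparse.RequestParser()
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)

print(uuid_list("c9f1f5a0-0000-4000-8000-000000000000"))
```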

View File

@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model): def post(self, app_model):
file = request.files["file"] file = request.files['file']
try: try:
response = AudioService.transcript_asr( response = AudioService.transcript_asr(
@ -85,27 +85,31 @@ class ChatMessageTextApi(Resource):
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, location="json") parser.add_argument('message_id', type=str, location='json')
parser.add_argument("text", type=str, location="json") parser.add_argument('text', type=str, location='json')
parser.add_argument("voice", type=str, location="json") parser.add_argument('voice', type=str, location='json')
parser.add_argument("streaming", type=bool, location="json") parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args() args = parser.parse_args()
message_id = args.get("message_id", None) message_id = args.get('message_id', None)
text = args.get("text", None) text = args.get('text', None)
if ( if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] and app_model.workflow
and app_model.workflow and app_model.workflow.features_dict):
and app_model.workflow.features_dict text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
): voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
'voice')
except Exception: except Exception:
voice = None voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice) response = AudioService.transcript_tts(
app_model=app_model,
text=text,
message_id=message_id,
voice=voice
)
return response return response
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
@ -141,12 +145,12 @@ class TextModesApi(Resource):
def get(self, app_model): def get(self, app_model):
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("language", type=str, required=True, location="args") parser.add_argument('language', type=str, required=True, location='args')
args = parser.parse_args() args = parser.parse_args()
response = AudioService.transcript_tts_voices( response = AudioService.transcript_tts_voices(
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
language=args["language"], language=args['language'],
) )
return response return response
@ -175,6 +179,6 @@ class TextModesApi(Resource):
raise InternalServerError() raise InternalServerError()
api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text") api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text')
api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio") api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio')
api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices") api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices')
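
The text-to-audio handler resolves the voice from the request first, then falls back to workflow features or the app model config. A simplified sketch of that fallback; flattened dicts stand in for the real model objects, and the voice names are invented.

```python
# Simplified sketch of the voice fallback in ChatMessageTextApi.
from typing import Optional


def resolve_voice(args: dict, workflow_tts: Optional[dict], app_config_tts: Optional[dict]):
    if workflow_tts is not None:
        # advanced-chat / workflow apps read text_to_speech from workflow features
        return args.get("voice") or workflow_tts.get("voice")
    # other app modes fall back to the app model config, tolerating missing config
    return args.get("voice") or (app_config_tts or {}).get("voice")


print(resolve_voice({"voice": None}, {"voice": "voice-a"}, None))       # -> voice-a
print(resolve_voice({"voice": "voice-b"}, None, {"voice": "voice-a"}))  # -> voice-b
```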

View File

@ -17,7 +17,6 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ( from core.errors.error import (
@ -32,33 +31,37 @@ from libs.helper import uuid_value
from libs.login import login_required from libs.login import login_required
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
# define completion message api for user # define completion message api for user
class CompletionMessageApi(Resource): class CompletionMessageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model): def post(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument("query", type=str, location="json", default="") parser.add_argument('query', type=str, location='json', default='')
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument("model_config", type=dict, required=True, location="json") parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] != "blocking" streaming = args['response_mode'] != 'blocking'
args["auto_generate_name"] = False args['auto_generate_name'] = False
account = flask_login.current_user account = flask_login.current_user
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming app_model=app_model,
user=account,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -94,7 +97,7 @@ class CompletionMessageStopApi(Resource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {"result": "success"}, 200 return {'result': 'success'}, 200
class ChatMessageApi(Resource): class ChatMessageApi(Resource):
@ -104,23 +107,27 @@ class ChatMessageApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
def post(self, app_model): def post(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument("query", type=str, required=True, location="json") parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument("model_config", type=dict, required=True, location="json") parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument("conversation_id", type=uuid_value, location="json") parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] != "blocking" streaming = args['response_mode'] != 'blocking'
args["auto_generate_name"] = False args['auto_generate_name'] = False
account = flask_login.current_user account = flask_login.current_user
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming app_model=app_model,
user=account,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=streaming
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -137,8 +144,6 @@ class ChatMessageApi(Resource):
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e: except InvokeError as e:
raise CompletionRequestError(e.description) raise CompletionRequestError(e.description)
except (ValueError, AppInvokeQuotaExceededError) as e: except (ValueError, AppInvokeQuotaExceededError) as e:
@ -158,10 +163,10 @@ class ChatMessageStopApi(Resource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id)
return {"result": "success"}, 200 return {'result': 'success'}, 200
api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages") api.add_resource(CompletionMessageApi, '/apps/<uuid:app_id>/completion-messages')
api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop") api.add_resource(CompletionMessageStopApi, '/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop')
api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages") api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages')
api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop") api.add_resource(ChatMessageStopApi, '/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop')
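
The debugger chat/completion endpoints above read a JSON body and derive a streaming flag from response_mode. A runnable sketch of that parsing; the payload and URL are invented, while the argument declarations mirror the diff.

```python
# Runnable sketch of the JSON-body parsing and streaming flag above.
from flask import Flask
from flask_restful import reqparse

app = Flask(__name__)

parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")

payload = {"inputs": {"name": "example"}, "query": "Hello", "response_mode": "streaming"}
with app.test_request_context("/apps/some-app-id/chat-messages", method="POST", json=payload):
    args = parser.parse_args()
    streaming = args["response_mode"] != "blocking"
    args["auto_generate_name"] = False
    print(streaming, args["retriever_from"])  # -> True dev
```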

View File

@ -20,12 +20,13 @@ from fields.conversation_fields import (
conversation_pagination_fields, conversation_pagination_fields,
conversation_with_summary_pagination_fields, conversation_with_summary_pagination_fields,
) )
from libs.helper import DatetimeString from libs.helper import datetime_string
from libs.login import login_required from libs.login import login_required
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
class CompletionConversationApi(Resource): class CompletionConversationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -35,23 +36,24 @@ class CompletionConversationApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") parser.add_argument('keyword', type=str, location='args')
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument( parser.add_argument('annotation_status', type=str,
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
) parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion')
if args["keyword"]: if args['keyword']:
query = query.join(Message, Message.conversation_id == Conversation.id).filter( query = query.join(
Message, Message.conversation_id == Conversation.id
).filter(
or_( or_(
Message.query.ilike("%{}%".format(args["keyword"])), Message.query.ilike('%{}%'.format(args['keyword'])),
Message.answer.ilike("%{}%".format(args["keyword"])), Message.answer.ilike('%{}%'.format(args['keyword']))
) )
) )
@ -59,8 +61,8 @@ class CompletionConversationApi(Resource):
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
@ -68,8 +70,8 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at >= start_datetime_utc) query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=59) end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
@ -77,25 +79,29 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc) query = query.where(Conversation.created_at < end_datetime_utc)
if args["annotation_status"] == "annotated": if args['annotation_status'] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( query = query.options(joinedload(Conversation.message_annotations)).join(
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
) )
elif args["annotation_status"] == "not_annotated": elif args['annotation_status'] == "not_annotated":
query = ( query = query.outerjoin(
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
.group_by(Conversation.id) ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
.having(func.count(MessageAnnotation.id) == 0)
)
query = query.order_by(Conversation.created_at.desc()) query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) conversations = db.paginate(
query,
page=args['page'],
per_page=args['limit'],
error_out=False
)
return conversations return conversations
class CompletionConversationDetailApi(Resource): class CompletionConversationDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -117,11 +123,8 @@ class CompletionConversationDetailApi(Resource):
raise Forbidden() raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
conversation = ( conversation = db.session.query(Conversation) \
db.session.query(Conversation) .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation: if not conversation:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -129,10 +132,11 @@ class CompletionConversationDetailApi(Resource):
conversation.is_deleted = True conversation.is_deleted = True
db.session.commit() db.session.commit()
return {"result": "success"}, 204 return {'result': 'success'}, 204
class ChatConversationApi(Resource): class ChatConversationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -142,28 +146,20 @@ class ChatConversationApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") parser.add_argument('keyword', type=str, location='args')
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument( parser.add_argument('annotation_status', type=str,
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" choices=['annotated', 'not_annotated', 'all'], default='all', location='args')
) parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args')
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args')
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
parser.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
args = parser.parse_args() args = parser.parse_args()
subquery = ( subquery = (
db.session.query( db.session.query(
Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") Conversation.id.label('conversation_id'),
EndUser.session_id.label('from_end_user_session_id')
) )
.outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id)
.subquery() .subquery()
@ -171,95 +167,78 @@ class ChatConversationApi(Resource):
query = db.select(Conversation).where(Conversation.app_id == app_model.id) query = db.select(Conversation).where(Conversation.app_id == app_model.id)
if args["keyword"]: if args['keyword']:
keyword_filter = "%{}%".format(args["keyword"]) keyword_filter = '%{}%'.format(args['keyword'])
query = ( query = query.join(
query.join( Message, Message.conversation_id == Conversation.id,
Message, ).join(
Message.conversation_id == Conversation.id, subquery, subquery.c.conversation_id == Conversation.id
) ).filter(
.join(subquery, subquery.c.conversation_id == Conversation.id) or_(
.filter( Message.query.ilike(keyword_filter),
or_( Message.answer.ilike(keyword_filter),
Message.query.ilike(keyword_filter), Conversation.name.ilike(keyword_filter),
Message.answer.ilike(keyword_filter), Conversation.introduction.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter), subquery.c.from_end_user_session_id.ilike(keyword_filter)
Conversation.introduction.ilike(keyword_filter), ),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
),
)
) )
account = current_user account = current_user
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]: query = query.where(Conversation.created_at >= start_datetime_utc)
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=59) end_datetime = end_datetime.replace(second=59)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
match args["sort_by"]: query = query.where(Conversation.created_at < end_datetime_utc)
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args["annotation_status"] == "annotated": if args['annotation_status'] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( query = query.options(joinedload(Conversation.message_annotations)).join(
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
) )
elif args["annotation_status"] == "not_annotated": elif args['annotation_status'] == "not_annotated":
query = ( query = query.outerjoin(
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
.group_by(Conversation.id) ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
.having(func.count(MessageAnnotation.id) == 0)
)
if args["message_count_gte"] and args["message_count_gte"] >= 1: if args['message_count_gte'] and args['message_count_gte'] >= 1:
query = ( query = (
query.options(joinedload(Conversation.messages)) query.options(joinedload(Conversation.messages))
.join(Message, Message.conversation_id == Conversation.id) .join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id) .group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"]) .having(func.count(Message.id) >= args['message_count_gte'])
) )
if app_model.mode == AppMode.ADVANCED_CHAT.value: if app_model.mode == AppMode.ADVANCED_CHAT.value:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
match args["sort_by"]: query = query.order_by(Conversation.created_at.desc())
case "created_at":
query = query.order_by(Conversation.created_at.asc())
case "-created_at":
query = query.order_by(Conversation.created_at.desc())
case "updated_at":
query = query.order_by(Conversation.updated_at.asc())
case "-updated_at":
query = query.order_by(Conversation.updated_at.desc())
case _:
query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) conversations = db.paginate(
query,
page=args['page'],
per_page=args['limit'],
error_out=False
)
return conversations return conversations
class ChatConversationDetailApi(Resource): class ChatConversationDetailApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -281,11 +260,8 @@ class ChatConversationDetailApi(Resource):
raise Forbidden() raise Forbidden()
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
conversation = ( conversation = db.session.query(Conversation) \
db.session.query(Conversation) .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation: if not conversation:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
@ -293,21 +269,18 @@ class ChatConversationDetailApi(Resource):
conversation.is_deleted = True conversation.is_deleted = True
db.session.commit() db.session.commit()
return {"result": "success"}, 204 return {'result': 'success'}, 204
api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations") api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>") api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations") api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>") api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
def _get_conversation(app_model, conversation_id): def _get_conversation(app_model, conversation_id):
conversation = ( conversation = db.session.query(Conversation) \
db.session.query(Conversation) .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first()
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation: if not conversation:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")

View File

@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource):
@marshal_with(paginated_conversation_variable_fields) @marshal_with(paginated_conversation_variable_fields)
def get(self, app_model): def get(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("conversation_id", type=str, location="args") parser.add_argument('conversation_id', type=str, location='args')
args = parser.parse_args() args = parser.parse_args()
stmt = ( stmt = (
@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource):
.where(ConversationVariable.app_id == app_model.id) .where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at) .order_by(ConversationVariable.created_at)
) )
if args["conversation_id"]: if args['conversation_id']:
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"]) stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id'])
else: else:
raise ValueError("conversation_id is required") raise ValueError('conversation_id is required')
# NOTE: This is a temporary solution to avoid performance issues. # NOTE: This is a temporary solution to avoid performance issues.
page = 1 page = 1
@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource):
rows = session.scalars(stmt).all() rows = session.scalars(stmt).all()
return { return {
"page": page, 'page': page,
"limit": page_size, 'limit': page_size,
"total": len(rows), 'total': len(rows),
"has_more": False, 'has_more': False,
"data": [ 'data': [
{ {
"created_at": row.created_at, 'created_at': row.created_at,
"updated_at": row.updated_at, 'updated_at': row.updated_at,
**row.to_variable().model_dump(), **row.to_variable().model_dump(),
} }
for row in rows for row in rows
@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource):
} }
api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables") api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables')
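The handler above builds a 2.0-style SQLAlchemy select() and materializes it with session.scalars(...).all(). A small runnable sketch of that statement-plus-scalars pattern follows; it is not part of the diff, and the in-memory ConversationVariable model and SQLite engine are stand-ins for Dify's real table and database.

    # Sketch only: 2.0-style select() executed via Session.scalars().
    from datetime import datetime, timezone
    from sqlalchemy import DateTime, String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class ConversationVariable(Base):
        __tablename__ = "conversation_variables"
        id: Mapped[int] = mapped_column(primary_key=True)
        app_id: Mapped[str] = mapped_column(String(36))
        conversation_id: Mapped[str] = mapped_column(String(36))
        created_at: Mapped[datetime] = mapped_column(
            DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
        )

    engine = create_engine("sqlite+pysqlite:///:memory:")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(ConversationVariable(app_id="app-1", conversation_id="conv-1"))
        session.commit()
        stmt = (
            select(ConversationVariable)
            .where(ConversationVariable.app_id == "app-1")
            .where(ConversationVariable.conversation_id == "conv-1")
            .order_by(ConversationVariable.created_at)
        )
        rows = session.scalars(stmt).all()  # one mapped object per matching row
        print(len(rows))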

View File

@ -2,128 +2,116 @@ from libs.exception import BaseHTTPException
class AppNotFoundError(BaseHTTPException): class AppNotFoundError(BaseHTTPException):
error_code = "app_not_found" error_code = 'app_not_found'
description = "App not found." description = "App not found."
code = 404 code = 404
class ProviderNotInitializeError(BaseHTTPException): class ProviderNotInitializeError(BaseHTTPException):
error_code = "provider_not_initialize" error_code = 'provider_not_initialize'
description = ( description = "No valid model provider credentials found. " \
"No valid model provider credentials found. " "Please go to Settings -> Model Provider to complete your provider credentials."
"Please go to Settings -> Model Provider to complete your provider credentials."
)
code = 400 code = 400
class ProviderQuotaExceededError(BaseHTTPException): class ProviderQuotaExceededError(BaseHTTPException):
error_code = "provider_quota_exceeded" error_code = 'provider_quota_exceeded'
description = ( description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
"Your quota for Dify Hosted Model Provider has been exhausted. " "Please go to Settings -> Model Provider to complete your own provider credentials."
"Please go to Settings -> Model Provider to complete your own provider credentials."
)
code = 400 code = 400
class ProviderModelCurrentlyNotSupportError(BaseHTTPException): class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
error_code = "model_currently_not_support" error_code = 'model_currently_not_support'
description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." description = "Dify Hosted OpenAI trial currently not support the GPT-4 model."
code = 400 code = 400
class ConversationCompletedError(BaseHTTPException): class ConversationCompletedError(BaseHTTPException):
error_code = "conversation_completed" error_code = 'conversation_completed'
description = "The conversation has ended. Please start a new conversation." description = "The conversation has ended. Please start a new conversation."
code = 400 code = 400
class AppUnavailableError(BaseHTTPException): class AppUnavailableError(BaseHTTPException):
error_code = "app_unavailable" error_code = 'app_unavailable'
description = "App unavailable, please check your app configurations." description = "App unavailable, please check your app configurations."
code = 400 code = 400
class CompletionRequestError(BaseHTTPException): class CompletionRequestError(BaseHTTPException):
error_code = "completion_request_error" error_code = 'completion_request_error'
description = "Completion request failed." description = "Completion request failed."
code = 400 code = 400
class AppMoreLikeThisDisabledError(BaseHTTPException): class AppMoreLikeThisDisabledError(BaseHTTPException):
error_code = "app_more_like_this_disabled" error_code = 'app_more_like_this_disabled'
description = "The 'More like this' feature is disabled. Please refresh your page." description = "The 'More like this' feature is disabled. Please refresh your page."
code = 403 code = 403
class NoAudioUploadedError(BaseHTTPException): class NoAudioUploadedError(BaseHTTPException):
error_code = "no_audio_uploaded" error_code = 'no_audio_uploaded'
description = "Please upload your audio." description = "Please upload your audio."
code = 400 code = 400
class AudioTooLargeError(BaseHTTPException): class AudioTooLargeError(BaseHTTPException):
error_code = "audio_too_large" error_code = 'audio_too_large'
description = "Audio size exceeded. {message}" description = "Audio size exceeded. {message}"
code = 413 code = 413
class UnsupportedAudioTypeError(BaseHTTPException): class UnsupportedAudioTypeError(BaseHTTPException):
error_code = "unsupported_audio_type" error_code = 'unsupported_audio_type'
description = "Audio type not allowed." description = "Audio type not allowed."
code = 415 code = 415
class ProviderNotSupportSpeechToTextError(BaseHTTPException): class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = "provider_not_support_speech_to_text" error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text." description = "Provider not support speech to text."
code = 400 code = 400
class NoFileUploadedError(BaseHTTPException): class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded" error_code = 'no_file_uploaded'
description = "Please upload your file." description = "Please upload your file."
code = 400 code = 400
class TooManyFilesError(BaseHTTPException): class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files" error_code = 'too_many_files'
description = "Only one file is allowed." description = "Only one file is allowed."
code = 400 code = 400
class DraftWorkflowNotExist(BaseHTTPException): class DraftWorkflowNotExist(BaseHTTPException):
error_code = "draft_workflow_not_exist" error_code = 'draft_workflow_not_exist'
description = "Draft workflow need to be initialized." description = "Draft workflow need to be initialized."
code = 400 code = 400
class DraftWorkflowNotSync(BaseHTTPException): class DraftWorkflowNotSync(BaseHTTPException):
error_code = "draft_workflow_not_sync" error_code = 'draft_workflow_not_sync'
description = "Workflow graph might have been modified, please refresh and resubmit." description = "Workflow graph might have been modified, please refresh and resubmit."
code = 400 code = 400
class TracingConfigNotExist(BaseHTTPException): class TracingConfigNotExist(BaseHTTPException):
error_code = "trace_config_not_exist" error_code = 'trace_config_not_exist'
description = "Trace config not exist." description = "Trace config not exist."
code = 400 code = 400
class TracingConfigIsExist(BaseHTTPException): class TracingConfigIsExist(BaseHTTPException):
error_code = "trace_config_is_exist" error_code = 'trace_config_is_exist'
description = "Trace config is exist." description = "Trace config is exist."
code = 400 code = 400
class TracingConfigCheckError(BaseHTTPException): class TracingConfigCheckError(BaseHTTPException):
error_code = "trace_config_check_error" error_code = 'trace_config_check_error'
description = "Invalid Credentials." description = "Invalid Credentials."
code = 400 code = 400
class InvokeRateLimitError(BaseHTTPException):
"""Raised when the Invoke returns rate limit error."""
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
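Each class in this module follows the same convention: an error_code string, a human-readable description, and an HTTP status code. The sketch below shows that convention with a stand-in BaseHTTPException; it is not part of the diff, and the to_response helper is purely illustrative rather than the project's actual error handling.

    # Sketch only: the error_code / description / code convention.
    class BaseHTTPException(Exception):
        # Stand-in for the project's own base class from libs.exception.
        error_code: str = "unknown"
        description: str = "Unknown error."
        code: int = 500

        def to_response(self) -> tuple[dict, int]:
            # Shape loosely mirrors what a JSON error handler might return.
            return {"code": self.error_code, "message": self.description, "status": self.code}, self.code

    class InvokeRateLimitError(BaseHTTPException):
        """Raised when the upstream invoke returns a rate limit error."""
        error_code = "rate_limit_error"
        description = "Rate Limit Error"
        code = 429

    try:
        raise InvokeRateLimitError()
    except BaseHTTPException as e:
        body, status = e.to_response()
        print(status, body)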

View File

@ -24,21 +24,21 @@ class RuleGenerateApi(Resource):
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") parser.add_argument('instruction', type=str, required=True, nullable=False, location='json')
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json')
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") parser.add_argument('no_variable', type=bool, required=True, default=False, location='json')
args = parser.parse_args() args = parser.parse_args()
account = current_user account = current_user
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512")) PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512'))
try: try:
rules = LLMGenerator.generate_rule_config( rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id, tenant_id=account.current_tenant_id,
instruction=args["instruction"], instruction=args['instruction'],
model_config=args["model_config"], model_config=args['model_config'],
no_variable=args["no_variable"], no_variable=args['no_variable'],
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS, rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -52,4 +52,4 @@ class RuleGenerateApi(Resource):
return rules return rules
api.add_resource(RuleGenerateApi, "/rule-generate") api.add_resource(RuleGenerateApi, '/rule-generate')
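The handler above combines flask_restful's reqparse with an environment-driven token cap. Below is a small self-contained sketch of that combination; it is not part of the diff, and the RuleGenerateSketch resource, its route, and its response shape are placeholders rather than Dify's real endpoint.

    # Sketch only: reqparse on a JSON body plus an env-var token cap with a default.
    import os
    from flask import Flask
    from flask_restful import Api, Resource, reqparse

    app = Flask(__name__)
    api = Api(app)

    class RuleGenerateSketch(Resource):
        def post(self):
            parser = reqparse.RequestParser()
            parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
            parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
            args = parser.parse_args()
            # Fall back to 512 tokens when the env var is unset.
            max_tokens = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))
            return {"instruction": args["instruction"], "max_tokens": max_tokens}

    api.add_resource(RuleGenerateSketch, "/rule-generate-sketch")

    with app.test_client() as client:
        resp = client.post("/rule-generate-sketch", json={"instruction": "hi", "no_variable": True})
        print(resp.get_json())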

View File

@ -33,9 +33,9 @@ from services.message_service import MessageService
class ChatMessageListApi(Resource): class ChatMessageListApi(Resource):
message_infinite_scroll_pagination_fields = { message_infinite_scroll_pagination_fields = {
"limit": fields.Integer, 'limit': fields.Integer,
"has_more": fields.Boolean, 'has_more': fields.Boolean,
"data": fields.List(fields.Nested(message_detail_fields)), 'data': fields.List(fields.Nested(message_detail_fields))
} }
@setup_required @setup_required
@ -45,69 +45,55 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_fields) @marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model): def get(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument("first_id", type=uuid_value, location="args") parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args() args = parser.parse_args()
conversation = ( conversation = db.session.query(Conversation).filter(
db.session.query(Conversation) Conversation.id == args['conversation_id'],
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) Conversation.app_id == app_model.id
.first() ).first()
)
if not conversation: if not conversation:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
if args["first_id"]: if args['first_id']:
first_message = ( first_message = db.session.query(Message) \
db.session.query(Message) .filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first()
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first()
)
if not first_message: if not first_message:
raise NotFound("First message not found") raise NotFound("First message not found")
history_messages = ( history_messages = db.session.query(Message).filter(
db.session.query(Message) Message.conversation_id == conversation.id,
.filter( Message.created_at < first_message.created_at,
Message.conversation_id == conversation.id, Message.id != first_message.id
Message.created_at < first_message.created_at, ) \
Message.id != first_message.id, .order_by(Message.created_at.desc()).limit(args['limit']).all()
)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
)
else: else:
history_messages = ( history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
db.session.query(Message) .order_by(Message.created_at.desc()).limit(args['limit']).all()
.filter(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
)
has_more = False has_more = False
if len(history_messages) == args["limit"]: if len(history_messages) == args['limit']:
current_page_first_message = history_messages[-1] current_page_first_message = history_messages[-1]
rest_count = ( rest_count = db.session.query(Message).filter(
db.session.query(Message) Message.conversation_id == conversation.id,
.filter( Message.created_at < current_page_first_message.created_at,
Message.conversation_id == conversation.id, Message.id != current_page_first_message.id
Message.created_at < current_page_first_message.created_at, ).count()
Message.id != current_page_first_message.id,
)
.count()
)
if rest_count > 0: if rest_count > 0:
has_more = True has_more = True
history_messages = list(reversed(history_messages)) history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) return InfiniteScrollPagination(
data=history_messages,
limit=args['limit'],
has_more=has_more
)
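The listing above pages chat history backwards from an anchor message (first_id): it fetches one page of strictly older rows newest-first, reverses them into chronological order for display, and reports has_more. A plain-Python sketch of that anchor-based pagination follows; it is not part of the diff, and Msg and fetch_older are illustrative names.

    # Sketch only: anchor-based "load older messages" pagination on in-memory data.
    from dataclasses import dataclass
    from datetime import datetime, timedelta

    @dataclass
    class Msg:
        id: int
        created_at: datetime

    def fetch_older(messages: list[Msg], first_id: int | None, limit: int) -> tuple[list[Msg], bool]:
        newest_first = sorted(messages, key=lambda m: m.created_at, reverse=True)
        if first_id is not None:
            anchor = next(m for m in newest_first if m.id == first_id)
            candidates = [m for m in newest_first if m.created_at < anchor.created_at and m.id != anchor.id]
        else:
            candidates = newest_first
        page = candidates[:limit]
        has_more = len(candidates) > len(page)  # anything older left beyond this page?
        return list(reversed(page)), has_more   # chronological order for display

    base = datetime(2024, 8, 19, 12, 0)
    msgs = [Msg(i, base + timedelta(minutes=i)) for i in range(10)]
    page, more = fetch_older(msgs, first_id=7, limit=3)
    print([m.id for m in page], more)  # [4, 5, 6] True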
class MessageFeedbackApi(Resource): class MessageFeedbackApi(Resource):
@ -117,46 +103,49 @@ class MessageFeedbackApi(Resource):
@get_app_model @get_app_model
def post(self, app_model): def post(self, app_model):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("message_id", required=True, type=uuid_value, location="json") parser.add_argument('message_id', required=True, type=uuid_value, location='json')
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
args = parser.parse_args() args = parser.parse_args()
message_id = str(args["message_id"]) message_id = str(args['message_id'])
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id
).first()
if not message: if not message:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
feedback = message.admin_feedback feedback = message.admin_feedback
if not args["rating"] and feedback: if not args['rating'] and feedback:
db.session.delete(feedback) db.session.delete(feedback)
elif args["rating"] and feedback: elif args['rating'] and feedback:
feedback.rating = args["rating"] feedback.rating = args['rating']
elif not args["rating"] and not feedback: elif not args['rating'] and not feedback:
raise ValueError("rating cannot be None when feedback not exists") raise ValueError('rating cannot be None when feedback not exists')
else: else:
feedback = MessageFeedback( feedback = MessageFeedback(
app_id=app_model.id, app_id=app_model.id,
conversation_id=message.conversation_id, conversation_id=message.conversation_id,
message_id=message.id, message_id=message.id,
rating=args["rating"], rating=args['rating'],
from_source="admin", from_source='admin',
from_account_id=current_user.id, from_account_id=current_user.id
) )
db.session.add(feedback) db.session.add(feedback)
db.session.commit() db.session.commit()
return {"result": "success"} return {'result': 'success'}
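The feedback handler above is effectively an upsert with a delete path: clearing a rating removes the row, re-rating updates it, rating for the first time creates it, and clearing a rating that does not exist is an error. The decision table below is a plain-Python sketch of those branches, not part of the diff; the dict stands in for the MessageFeedback row.

    # Sketch only: the create / update / delete branches of the feedback endpoint.
    def apply_feedback(existing: dict | None, rating: str | None) -> dict | None:
        """Return the feedback record that should remain after the request."""
        if rating is None and existing is not None:
            return None                                      # clearing deletes the row
        if rating is not None and existing is not None:
            return {**existing, "rating": rating}            # re-rating updates in place
        if rating is None and existing is None:
            raise ValueError("rating cannot be None when feedback not exists")
        return {"rating": rating, "from_source": "admin"}    # first rating creates a row

    print(apply_feedback(None, "like"))
    print(apply_feedback({"rating": "like", "from_source": "admin"}, "dislike"))
    print(apply_feedback({"rating": "like", "from_source": "admin"}, None))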
class MessageAnnotationApi(Resource): class MessageAnnotationApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("annotation") @cloud_edition_billing_resource_check('annotation')
@get_app_model @get_app_model
@marshal_with(annotation_fields) @marshal_with(annotation_fields)
def post(self, app_model): def post(self, app_model):
@ -164,10 +153,10 @@ class MessageAnnotationApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("message_id", required=False, type=uuid_value, location="json") parser.add_argument('message_id', required=False, type=uuid_value, location='json')
parser.add_argument("question", required=True, type=str, location="json") parser.add_argument('question', required=True, type=str, location='json')
parser.add_argument("answer", required=True, type=str, location="json") parser.add_argument('answer', required=True, type=str, location='json')
parser.add_argument("annotation_reply", required=False, type=dict, location="json") parser.add_argument('annotation_reply', required=False, type=dict, location='json')
args = parser.parse_args() args = parser.parse_args()
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
@ -180,9 +169,11 @@ class MessageAnnotationCountApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
def get(self, app_model): def get(self, app_model):
count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() count = db.session.query(MessageAnnotation).filter(
MessageAnnotation.app_id == app_model.id
).count()
return {"count": count} return {'count': count}
class MessageSuggestedQuestionApi(Resource): class MessageSuggestedQuestionApi(Resource):
@ -195,7 +186,10 @@ class MessageSuggestedQuestionApi(Resource):
try: try:
questions = MessageService.get_suggested_questions_after_answer( questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER app_model=app_model,
message_id=message_id,
user=current_user,
invoke_from=InvokeFrom.DEBUGGER
) )
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message not found") raise NotFound("Message not found")
@ -215,7 +209,7 @@ class MessageSuggestedQuestionApi(Resource):
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
return {"data": questions} return {'data': questions}
class MessageApi(Resource): class MessageApi(Resource):
@ -227,7 +221,10 @@ class MessageApi(Resource):
def get(self, app_model, message_id): def get(self, app_model, message_id):
message_id = str(message_id) message_id = str(message_id)
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id
).first()
if not message: if not message:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -235,9 +232,9 @@ class MessageApi(Resource):
return message return message
api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions") api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions')
api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages") api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages')
api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks") api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks')
api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations") api.add_resource(MessageAnnotationApi, '/apps/<uuid:app_id>/annotations')
api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count") api.add_resource(MessageAnnotationCountApi, '/apps/<uuid:app_id>/annotations/count')
api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message") api.add_resource(MessageApi, '/apps/<uuid:app_id>/messages/<uuid:message_id>', endpoint='console_message')

View File

@ -19,35 +19,37 @@ from services.app_model_config_service import AppModelConfigService
class ModelConfigResource(Resource): class ModelConfigResource(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model): def post(self, app_model):
"""Modify app model config""" """Modify app model config"""
# validate config # validate config
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) tenant_id=current_user.current_tenant_id,
config=request.json,
app_mode=AppMode.value_of(app_model.mode)
) )
new_app_model_config = AppModelConfig( new_app_model_config = AppModelConfig(
app_id=app_model.id, app_id=app_model.id,
created_by=current_user.id,
updated_by=current_user.id,
) )
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config # get original app model config
original_app_model_config: AppModelConfig = ( original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() AppModelConfig.id == app_model.app_model_config_id
) ).first()
agent_mode = original_app_model_config.agent_mode_dict agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input # decrypt agent tool parameters if it's secret-input
parameter_map = {} parameter_map = {}
masked_parameter_map = {} masked_parameter_map = {}
tool_map = {} tool_map = {}
for tool in agent_mode.get("tools") or []: for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3: if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue continue
@ -64,7 +66,7 @@ class ModelConfigResource(Resource):
tool_runtime=tool_runtime, tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id, provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type, provider_type=agent_tool_entity.provider_type,
identity_id=f"AGENT.{app_model.id}", identity_id=f'AGENT.{app_model.id}'
) )
except Exception as e: except Exception as e:
continue continue
@ -77,18 +79,18 @@ class ModelConfigResource(Resource):
parameters = {} parameters = {}
masked_parameter = {} masked_parameter = {}
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters parameter_map[key] = parameters
tool_map[key] = tool_runtime tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input # encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []: for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool) agent_tool_entity = AgentToolEntity(**tool)
# get tool # get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map: if key in tool_map:
tool_runtime = tool_map[key] tool_runtime = tool_map[key]
else: else:
@ -106,7 +108,7 @@ class ModelConfigResource(Resource):
tool_runtime=tool_runtime, tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id, provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type, provider_type=agent_tool_entity.provider_type,
identity_id=f"AGENT.{app_model.id}", identity_id=f'AGENT.{app_model.id}'
) )
manager.delete_tool_parameters_cache() manager.delete_tool_parameters_cache()
@ -114,17 +116,15 @@ class ModelConfigResource(Resource):
if agent_tool_entity.tool_parameters: if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map: if key not in masked_parameter_map:
continue continue
for masked_key, masked_value in masked_parameter_map[key].items(): for masked_key, masked_value in masked_parameter_map[key].items():
if ( if masked_key in agent_tool_entity.tool_parameters and \
masked_key in agent_tool_entity.tool_parameters agent_tool_entity.tool_parameters[masked_key] == masked_value:
and agent_tool_entity.tool_parameters[masked_key] == masked_value
):
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key) agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
# encrypt parameters # encrypt parameters
if agent_tool_entity.tool_parameters: if agent_tool_entity.tool_parameters:
tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config # update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode) new_app_model_config.agent_mode = json.dumps(agent_mode)
@ -135,9 +135,12 @@ class ModelConfigResource(Resource):
app_model.app_model_config_id = new_app_model_config.id app_model.app_model_config_id = new_app_model_config.id
db.session.commit() db.session.commit()
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) app_model_config_was_updated.send(
app_model,
app_model_config=new_app_model_config
)
return {"result": "success"} return {'result': 'success'}
api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config") api.add_resource(ModelConfigResource, '/apps/<uuid:app_id>/model-config')
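The agent-tool handling above masks secret-input parameters before they are shown to the client and, on save, restores the stored secret whenever the mask comes back untouched, so unchanged secrets are never overwritten with the mask string. The sketch below shows that round-trip with trivial stand-ins; it is not part of the diff, and SECRET_KEYS, MASK, and both helpers are illustrative rather than the real ToolParameterConfigurationManager API.

    # Sketch only: mask secrets on the way out, restore untouched masks on the way back in.
    SECRET_KEYS = {"api_key"}
    MASK = "[__HIDDEN__]"

    def mask(params: dict) -> dict:
        # Replace only secret-input values with the mask before sending to the client.
        return {k: (MASK if k in SECRET_KEYS else v) for k, v in params.items()}

    def restore_unchanged(submitted: dict, masked: dict, original: dict) -> dict:
        # If the client sent the mask back untouched, keep the stored secret;
        # any other value is treated as a newly entered secret.
        restored = dict(submitted)
        for key, masked_value in masked.items():
            if key in SECRET_KEYS and restored.get(key) == masked_value:
                restored[key] = original[key]
        return restored

    original = {"api_key": "sk-secret", "region": "eu"}
    shown_to_client = mask(original)
    submitted = {"api_key": MASK, "region": "us"}  # user changed region, left the key masked
    print(restore_unchanged(submitted, shown_to_client, original))
    # {'api_key': 'sk-secret', 'region': 'us'}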

View File

@ -18,11 +18,13 @@ class TraceAppConfigApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_id): def get(self, app_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="args") parser.add_argument('tracing_provider', type=str, required=True, location='args')
args = parser.parse_args() args = parser.parse_args()
try: try:
trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) trace_config = OpsService.get_tracing_app_config(
app_id=app_id, tracing_provider=args['tracing_provider']
)
if not trace_config: if not trace_config:
return {"has_not_configured": True} return {"has_not_configured": True}
return trace_config return trace_config
@ -35,17 +37,19 @@ class TraceAppConfigApi(Resource):
def post(self, app_id): def post(self, app_id):
"""Create a new trace app configuration""" """Create a new trace app configuration"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="json") parser.add_argument('tracing_provider', type=str, required=True, location='json')
parser.add_argument("tracing_config", type=dict, required=True, location="json") parser.add_argument('tracing_config', type=dict, required=True, location='json')
args = parser.parse_args() args = parser.parse_args()
try: try:
result = OpsService.create_tracing_app_config( result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] app_id=app_id,
tracing_provider=args['tracing_provider'],
tracing_config=args['tracing_config']
) )
if not result: if not result:
raise TracingConfigIsExist() raise TracingConfigIsExist()
if result.get("error"): if result.get('error'):
raise TracingConfigCheckError() raise TracingConfigCheckError()
return result return result
except Exception as e: except Exception as e:
@ -57,13 +61,15 @@ class TraceAppConfigApi(Resource):
def patch(self, app_id): def patch(self, app_id):
"""Update an existing trace app configuration""" """Update an existing trace app configuration"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="json") parser.add_argument('tracing_provider', type=str, required=True, location='json')
parser.add_argument("tracing_config", type=dict, required=True, location="json") parser.add_argument('tracing_config', type=dict, required=True, location='json')
args = parser.parse_args() args = parser.parse_args()
try: try:
result = OpsService.update_tracing_app_config( result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] app_id=app_id,
tracing_provider=args['tracing_provider'],
tracing_config=args['tracing_config']
) )
if not result: if not result:
raise TracingConfigNotExist() raise TracingConfigNotExist()
@ -77,11 +83,14 @@ class TraceAppConfigApi(Resource):
def delete(self, app_id): def delete(self, app_id):
"""Delete an existing trace app configuration""" """Delete an existing trace app configuration"""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("tracing_provider", type=str, required=True, location="args") parser.add_argument('tracing_provider', type=str, required=True, location='args')
args = parser.parse_args() args = parser.parse_args()
try: try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) result = OpsService.delete_tracing_app_config(
app_id=app_id,
tracing_provider=args['tracing_provider']
)
if not result: if not result:
raise TracingConfigNotExist() raise TracingConfigNotExist()
return {"result": "success"} return {"result": "success"}
@ -89,4 +98,4 @@ class TraceAppConfigApi(Resource):
raise e raise e
api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config") api.add_resource(TraceAppConfigApi, '/apps/<uuid:app_id>/trace-config')

View File

@ -1,5 +1,3 @@
from datetime import datetime, timezone
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -17,24 +15,23 @@ from models.model import Site
def parse_app_site_args(): def parse_app_site_args():
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("title", type=str, required=False, location="json") parser.add_argument('title', type=str, required=False, location='json')
parser.add_argument("icon_type", type=str, required=False, location="json") parser.add_argument('icon_type', type=str, required=False, location='json')
parser.add_argument("icon", type=str, required=False, location="json") parser.add_argument('icon', type=str, required=False, location='json')
parser.add_argument("icon_background", type=str, required=False, location="json") parser.add_argument('icon_background', type=str, required=False, location='json')
parser.add_argument("description", type=str, required=False, location="json") parser.add_argument('description', type=str, required=False, location='json')
parser.add_argument("default_language", type=supported_language, required=False, location="json") parser.add_argument('default_language', type=supported_language, required=False, location='json')
parser.add_argument("chat_color_theme", type=str, required=False, location="json") parser.add_argument('chat_color_theme', type=str, required=False, location='json')
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json')
parser.add_argument("customize_domain", type=str, required=False, location="json") parser.add_argument('customize_domain', type=str, required=False, location='json')
parser.add_argument("copyright", type=str, required=False, location="json") parser.add_argument('copyright', type=str, required=False, location='json')
parser.add_argument("privacy_policy", type=str, required=False, location="json") parser.add_argument('privacy_policy', type=str, required=False, location='json')
parser.add_argument("custom_disclaimer", type=str, required=False, location="json") parser.add_argument('custom_disclaimer', type=str, required=False, location='json')
parser.add_argument( parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'],
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" required=False,
) location='json')
parser.add_argument("prompt_public", type=bool, required=False, location="json") parser.add_argument('prompt_public', type=bool, required=False, location='json')
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json") parser.add_argument('show_workflow_steps', type=bool, required=False, location='json')
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
return parser.parse_args() return parser.parse_args()
@ -51,38 +48,38 @@ class AppSite(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() site = db.session.query(Site). \
filter(Site.app_id == app_model.id). \
one_or_404()
for attr_name in [ for attr_name in [
"title", 'title',
"icon_type", 'icon_type',
"icon", 'icon',
"icon_background", 'icon_background',
"description", 'description',
"default_language", 'default_language',
"chat_color_theme", 'chat_color_theme',
"chat_color_theme_inverted", 'chat_color_theme_inverted',
"customize_domain", 'customize_domain',
"copyright", 'copyright',
"privacy_policy", 'privacy_policy',
"custom_disclaimer", 'custom_disclaimer',
"customize_token_strategy", 'customize_token_strategy',
"prompt_public", 'prompt_public',
"show_workflow_steps", 'show_workflow_steps'
"use_icon_as_answer_icon",
]: ]:
value = args.get(attr_name) value = args.get(attr_name)
if value is not None: if value is not None:
setattr(site, attr_name, value) setattr(site, attr_name, value)
site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()
return site return site
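The site update above iterates over a whitelist of optional fields and only overwrites attributes the client actually sent, treating None as "leave unchanged". A minimal sketch of that partial-update loop follows; it is not part of the diff, and SiteSketch is a stand-in for the Site model.

    # Sketch only: whitelist-driven partial update with setattr.
    class SiteSketch:
        title = "Old title"
        icon = "default-icon"
        description = "Old description"

    def apply_partial_update(site, args: dict, allowed: list[str]) -> None:
        # Only overwrite attributes the client actually sent; None means "leave unchanged".
        for attr_name in allowed:
            value = args.get(attr_name)
            if value is not None:
                setattr(site, attr_name, value)

    site = SiteSketch()
    apply_partial_update(site, {"title": "New title", "description": None}, ["title", "icon", "description"])
    print(site.title, site.icon, site.description)  # New title default-icon Old description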
class AppSiteAccessTokenReset(Resource): class AppSiteAccessTokenReset(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -99,12 +96,10 @@ class AppSiteAccessTokenReset(Resource):
raise NotFound raise NotFound
site.code = Site.generate_code(16) site.code = Site.generate_code(16)
site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()
return site return site
api.add_resource(AppSite, "/apps/<uuid:app_id>/site") api.add_resource(AppSite, '/apps/<uuid:app_id>/site')
api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset") api.add_resource(AppSiteAccessTokenReset, '/apps/<uuid:app_id>/site/access-token-reset')

View File

@ -11,69 +11,13 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import DatetimeString from libs.helper import datetime_string
from libs.login import login_required from libs.login import login_required
from models.model import AppMode from models.model import AppMode
class DailyMessageStatistic(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
def get(self, app_model):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
sql_query = """SELECT
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(*) AS message_count
FROM
messages
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start"
arg_dict["start"] = start_datetime_utc
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end"
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date"
response_data = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs:
response_data.append({"date": str(i.date), "message_count": i.message_count})
return jsonify({"data": response_data})
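Every statistics resource in this file follows the same recipe: parse optional start/end bounds, localize them to the account's timezone, convert to UTC, and append parameterized filters to a raw SQL string before grouping by day. The sketch below builds such a window without executing it; it is not part of the diff, and build_window and the hard-coded app id are illustrative.

    # Sketch only: convert a local date window to UTC bounds and bind them into
    # a parameterized statistics query (built, not executed, here).
    from datetime import datetime
    import pytz

    def build_window(start: str | None, end: str | None, tz_name: str):
        tz = pytz.timezone(tz_name)
        sql = ("SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz)) AS date, "
               "COUNT(*) AS message_count FROM messages WHERE app_id = :app_id")
        params = {"tz": tz_name, "app_id": "app-1"}  # placeholder app id
        if start:
            local = tz.localize(datetime.strptime(start, "%Y-%m-%d %H:%M").replace(second=0))
            params["start"] = local.astimezone(pytz.utc)  # compare in UTC, like the stored rows
            sql += " AND created_at >= :start"
        if end:
            local = tz.localize(datetime.strptime(end, "%Y-%m-%d %H:%M").replace(second=0))
            params["end"] = local.astimezone(pytz.utc)
            sql += " AND created_at < :end"
        return sql + " GROUP BY date ORDER BY date", params

    sql, params = build_window("2024-08-01 00:00", "2024-08-19 00:00", "Asia/Shanghai")
    print(sql)
    print(params["start"].isoformat(), params["end"].isoformat())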
class DailyConversationStatistic(Resource): class DailyConversationStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -82,55 +26,58 @@ class DailyConversationStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = '''
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
COUNT(DISTINCT messages.conversation_id) AS conversation_count FROM messages where app_id = :app_id
FROM '''
messages arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start" sql_query += ' and created_at >= :start'
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += ' and created_at < :end'
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += ' GROUP BY date order by date'
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) response_data.append({
'date': str(i.date),
'conversation_count': i.conversation_count
})
return jsonify({"data": response_data}) return jsonify({
'data': response_data
})
class DailyTerminalsStatistic(Resource): class DailyTerminalsStatistic(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -139,52 +86,54 @@ class DailyTerminalsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = '''
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count FROM messages where app_id = :app_id
FROM '''
messages arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
WHERE
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start" sql_query += ' and created_at >= :start'
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += ' and created_at < :end'
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += ' GROUP BY date order by date'
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) response_data.append({
'date': str(i.date),
'terminal_count': i.terminal_count
})
return jsonify({"data": response_data}) return jsonify({
'data': response_data
})
class DailyTokenCostStatistic(Resource): class DailyTokenCostStatistic(Resource):
@ -196,55 +145,58 @@ class DailyTokenCostStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = '''
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count, (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
SUM(total_price) AS total_price sum(total_price) as total_price
FROM FROM messages where app_id = :app_id
messages '''
WHERE arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start" sql_query += ' and created_at >= :start'
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += ' and created_at < :end'
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += ' GROUP BY date order by date'
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append( response_data.append({
{"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} 'date': str(i.date),
) 'token_count': i.token_count,
'total_price': i.total_price,
'currency': 'USD'
})
return jsonify({"data": response_data}) return jsonify({
'data': response_data
})
class AverageSessionInteractionStatistic(Resource): class AverageSessionInteractionStatistic(Resource):
@ -256,72 +208,60 @@ class AverageSessionInteractionStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, AVG(subquery.message_count) AS interactions
AVG(subquery.message_count) AS interactions FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
FROM FROM conversations c
( JOIN messages m ON c.id = m.conversation_id
SELECT WHERE c.override_model_configs IS NULL AND c.app_id = :app_id"""
m.conversation_id, arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
COUNT(m.id) AS message_count
FROM
conversations c
JOIN
messages m
ON c.id = m.conversation_id
WHERE
c.override_model_configs IS NULL
AND c.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND c.created_at >= :start" sql_query += ' and c.created_at >= :start'
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND c.created_at < :end" sql_query += ' and c.created_at < :end'
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
sql_query += """ sql_query += """
GROUP BY m.conversation_id GROUP BY m.conversation_id) subquery
) subquery LEFT JOIN conversations c on c.id=subquery.conversation_id
LEFT JOIN GROUP BY date
conversations c ORDER BY date"""
ON c.id = subquery.conversation_id
GROUP BY
date
ORDER BY
date"""
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append( response_data.append({
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} 'date': str(i.date),
) 'interactions': float(i.interactions.quantize(Decimal('0.01')))
})
return jsonify({"data": response_data}) return jsonify({
'data': response_data
})
class UserSatisfactionRateStatistic(Resource): class UserSatisfactionRateStatistic(Resource):
@ -333,61 +273,57 @@ class UserSatisfactionRateStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = '''
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
COUNT(m.id) AS message_count, COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
COUNT(mf.id) AS feedback_count FROM messages m
FROM LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like'
messages m WHERE m.app_id = :app_id
LEFT JOIN '''
message_feedbacks mf arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
ON mf.message_id=m.id AND mf.rating='like'
WHERE
m.app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND m.created_at >= :start" sql_query += ' and m.created_at >= :start'
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND m.created_at < :end" sql_query += ' and m.created_at < :end'
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += ' GROUP BY date order by date'
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append( response_data.append({
{ 'date': str(i.date),
"date": str(i.date), 'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2),
"rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), })
}
)
return jsonify({"data": response_data}) return jsonify({
'data': response_data
})
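The "rate" returned above is likes per 1,000 messages, with a guard for days that have no messages. The same arithmetic in isolation (numbers are illustrative):

def satisfaction_rate(feedback_count: int, message_count: int) -> float:
    # Likes per 1,000 messages; 0 when the day has no messages at all.
    return round((feedback_count * 1000 / message_count) if message_count > 0 else 0, 2)

print(satisfaction_rate(37, 1200))  # 30.83
print(satisfaction_rate(0, 0))      # 0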
class AverageResponseTimeStatistic(Resource): class AverageResponseTimeStatistic(Resource):
@ -399,52 +335,56 @@ class AverageResponseTimeStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = '''
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
AVG(provider_response_latency) AS latency AVG(provider_response_latency) as latency
FROM FROM messages
messages WHERE app_id = :app_id
WHERE '''
app_id = :app_id""" arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start" sql_query += ' and created_at >= :start'
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += ' and created_at < :end'
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += ' GROUP BY date order by date'
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) response_data.append({
'date': str(i.date),
'latency': round(i.latency * 1000, 4)
})
return jsonify({"data": response_data}) return jsonify({
'data': response_data
})
class TokensPerSecondStatistic(Resource): class TokensPerSecondStatistic(Resource):
@ -456,62 +396,63 @@ class TokensPerSecondStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, CASE
CASE
WHEN SUM(provider_response_latency) = 0 THEN 0 WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
END as tokens_per_second END as tokens_per_second
FROM FROM messages
messages WHERE app_id = :app_id'''
WHERE arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
app_id = :app_id"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start" sql_query += ' and created_at >= :start'
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += ' and created_at < :end'
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += ' GROUP BY date order by date'
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) response_data.append({
'date': str(i.date),
'tps': round(i.tokens_per_second, 4)
})
return jsonify({"data": response_data}) return jsonify({
'data': response_data
})
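The tokens-per-second metric divides total answer tokens by total provider response latency in seconds, and the SQL CASE exists only to avoid a zero denominator. Equivalent Python with illustrative inputs:

def tokens_per_second(total_answer_tokens: int, total_latency_seconds: float) -> float:
    # Mirrors the SQL CASE: report 0 when no latency has been recorded yet.
    if total_latency_seconds == 0:
        return 0.0
    return round(total_answer_tokens / total_latency_seconds, 4)

print(tokens_per_second(15000, 420.5))  # 35.6718
print(tokens_per_second(0, 0))          # 0.0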
api.add_resource(DailyMessageStatistic, "/apps/<uuid:app_id>/statistics/daily-messages") api.add_resource(DailyConversationStatistic, '/apps/<uuid:app_id>/statistics/daily-conversations')
api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations") api.add_resource(DailyTerminalsStatistic, '/apps/<uuid:app_id>/statistics/daily-end-users')
api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users") api.add_resource(DailyTokenCostStatistic, '/apps/<uuid:app_id>/statistics/token-costs')
api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs") api.add_resource(AverageSessionInteractionStatistic, '/apps/<uuid:app_id>/statistics/average-session-interactions')
api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions") api.add_resource(UserSatisfactionRateStatistic, '/apps/<uuid:app_id>/statistics/user-satisfaction-rate')
api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate") api.add_resource(AverageResponseTimeStatistic, '/apps/<uuid:app_id>/statistics/average-response-time')
api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time") api.add_resource(TokensPerSecondStatistic, '/apps/<uuid:app_id>/statistics/tokens-per-second')
api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second")
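For reference, a hedged example of calling one of the statistics routes registered above. The console base URL, app id, and authorization header are placeholders; the console API normally relies on an authenticated session, so adapt the auth to your deployment:

import requests

BASE = "https://example.com/console/api"               # placeholder console base URL
APP_ID = "00000000-0000-0000-0000-000000000000"        # placeholder app id
HEADERS = {"Authorization": "Bearer <console-token>"}  # placeholder credentials

resp = requests.get(
    f"{BASE}/apps/{APP_ID}/statistics/tokens-per-second",
    params={"start": "2024-08-01 00:00", "end": "2024-08-19 00:00"},
    headers=HEADERS,
    timeout=10,
)
resp.raise_for_status()
for point in resp.json()["data"]:
    print(point["date"], point["tps"])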
View File
@ -64,51 +64,51 @@ class DraftWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
content_type = request.headers.get('Content-Type', '')
content_type = request.headers.get("Content-Type", "") if 'application/json' in content_type:
if "application/json" in content_type:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
parser.add_argument("features", type=dict, required=True, nullable=False, location="json") parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
parser.add_argument("hash", type=str, required=False, location="json") parser.add_argument('hash', type=str, required=False, location='json')
# TODO: set this to required=True after frontend is updated # TODO: set this to required=True after frontend is updated
parser.add_argument("environment_variables", type=list, required=False, location="json") parser.add_argument('environment_variables', type=list, required=False, location='json')
parser.add_argument("conversation_variables", type=list, required=False, location="json") parser.add_argument('conversation_variables', type=list, required=False, location='json')
args = parser.parse_args() args = parser.parse_args()
elif "text/plain" in content_type: elif 'text/plain' in content_type:
try: try:
data = json.loads(request.data.decode("utf-8")) data = json.loads(request.data.decode('utf-8'))
if "graph" not in data or "features" not in data: if 'graph' not in data or 'features' not in data:
raise ValueError("graph or features not found in data") raise ValueError('graph or features not found in data')
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict):
raise ValueError("graph or features is not a dict") raise ValueError('graph or features is not a dict')
args = { args = {
"graph": data.get("graph"), 'graph': data.get('graph'),
"features": data.get("features"), 'features': data.get('features'),
"hash": data.get("hash"), 'hash': data.get('hash'),
"environment_variables": data.get("environment_variables"), 'environment_variables': data.get('environment_variables'),
"conversation_variables": data.get("conversation_variables"), 'conversation_variables': data.get('conversation_variables'),
} }
except json.JSONDecodeError: except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400 return {'message': 'Invalid JSON data'}, 400
else: else:
abort(415) abort(415)
workflow_service = WorkflowService() workflow_service = WorkflowService()
try: try:
environment_variables_list = args.get("environment_variables") or [] environment_variables_list = args.get('environment_variables') or []
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
conversation_variables_list = args.get("conversation_variables") or [] conversation_variables_list = args.get('conversation_variables') or []
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
workflow = workflow_service.sync_draft_workflow( workflow = workflow_service.sync_draft_workflow(
app_model=app_model, app_model=app_model,
graph=args["graph"], graph=args['graph'],
features=args["features"], features=args['features'],
unique_hash=args.get("hash"), unique_hash=args.get('hash'),
account=current_user, account=current_user,
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource):
return { return {
"result": "success", "result": "success",
"hash": workflow.unique_hash, "hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
} }
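DraftWorkflowApi accepts the draft either as application/json (parsed by reqparse) or as text/plain whose body is itself JSON (parsed manually with json.loads); any other Content-Type is rejected with 415. A sketch of a client exercising both branches, with a placeholder host, app id, and token:

import json
import requests

url = "https://example.com/console/api/apps/<app_id>/workflows/draft"  # placeholders
headers = {"Authorization": "Bearer <console-token>"}
draft = {"graph": {"nodes": [], "edges": []}, "features": {}, "hash": None}

# Branch 1: ordinary JSON body.
requests.post(url, json=draft, headers=headers, timeout=10)

# Branch 2: text/plain body that still carries JSON; the handler json.loads() it
# and checks that "graph" and "features" are present and are dicts.
requests.post(
    url,
    data=json.dumps(draft),
    headers={**headers, "Content-Type": "text/plain"},
    timeout=10,
)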
@ -138,11 +138,13 @@ class DraftWorkflowImportApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json") parser.add_argument('data', type=str, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
workflow = AppDslService.import_and_overwrite_workflow( workflow = AppDslService.import_and_overwrite_workflow(
app_model=app_model, data=args["data"], account=current_user app_model=app_model,
data=args['data'],
account=current_user
) )
return workflow return workflow
@ -160,17 +162,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json") parser.add_argument('inputs', type=dict, location='json')
parser.add_argument("query", type=str, required=True, location="json", default="") parser.add_argument('query', type=str, required=True, location='json', default='')
parser.add_argument("files", type=list, location="json") parser.add_argument('files', type=list, location='json')
parser.add_argument("conversation_id", type=uuid_value, location="json") parser.add_argument('conversation_id', type=uuid_value, location='json')
args = parser.parse_args() args = parser.parse_args()
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -184,7 +190,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
class AdvancedChatDraftRunIterationNodeApi(Resource): class AdvancedChatDraftRunIterationNodeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -197,14 +202,18 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json") parser.add_argument('inputs', type=dict, location='json')
args = parser.parse_args() args = parser.parse_args()
try: try:
response = AppGenerateService.generate_single_iteration( response = AppGenerateService.generate_single_iteration(
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True app_model=app_model,
user=current_user,
node_id=node_id,
args=args,
streaming=True
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -218,7 +227,6 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
class WorkflowDraftRunIterationNodeApi(Resource): class WorkflowDraftRunIterationNodeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -231,14 +239,18 @@ class WorkflowDraftRunIterationNodeApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json") parser.add_argument('inputs', type=dict, location='json')
args = parser.parse_args() args = parser.parse_args()
try: try:
response = AppGenerateService.generate_single_iteration( response = AppGenerateService.generate_single_iteration(
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True app_model=app_model,
user=current_user,
node_id=node_id,
args=args,
streaming=True
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -252,7 +264,6 @@ class WorkflowDraftRunIterationNodeApi(Resource):
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
class DraftWorkflowRunApi(Resource): class DraftWorkflowRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -265,15 +276,19 @@ class DraftWorkflowRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument('files', type=list, required=False, location='json')
args = parser.parse_args() args = parser.parse_args()
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -296,10 +311,12 @@ class WorkflowTaskStopApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
return {"result": "success"} return {
"result": "success"
}
class DraftWorkflowNodeRunApi(Resource): class DraftWorkflowNodeRunApi(Resource):
@ -315,20 +332,24 @@ class DraftWorkflowNodeRunApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node( workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user app_model=app_model,
node_id=node_id,
user_inputs=args.get('inputs'),
account=current_user
) )
return workflow_node_execution return workflow_node_execution
class PublishedWorkflowApi(Resource): class PublishedWorkflowApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -341,7 +362,7 @@ class PublishedWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
# fetch published workflow by app_model # fetch published workflow by app_model
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app_model) workflow = workflow_service.get_published_workflow(app_model=app_model)
@ -360,11 +381,14 @@ class PublishedWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
return {"result": "success", "created_at": TimestampField().format(workflow.created_at)} return {
"result": "success",
"created_at": TimestampField().format(workflow.created_at)
}
class DefaultBlockConfigsApi(Resource): class DefaultBlockConfigsApi(Resource):
@ -379,7 +403,7 @@ class DefaultBlockConfigsApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
# Get default block configs # Get default block configs
workflow_service = WorkflowService() workflow_service = WorkflowService()
return workflow_service.get_default_block_configs() return workflow_service.get_default_block_configs()
@ -397,21 +421,24 @@ class DefaultBlockConfigApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args") parser.add_argument('q', type=str, location='args')
args = parser.parse_args() args = parser.parse_args()
filters = None filters = None
if args.get("q"): if args.get('q'):
try: try:
filters = json.loads(args.get("q")) filters = json.loads(args.get('q'))
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError("Invalid filters") raise ValueError('Invalid filters')
# Get default block configs # Get default block configs
workflow_service = WorkflowService() workflow_service = WorkflowService()
return workflow_service.get_default_block_config(node_type=block_type, filters=filters) return workflow_service.get_default_block_config(
node_type=block_type,
filters=filters
)
class ConvertToWorkflowApi(Resource): class ConvertToWorkflowApi(Resource):
@ -428,43 +455,41 @@ class ConvertToWorkflowApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
if request.data: if request.data:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json") parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json')
parser.add_argument("icon", type=str, required=False, nullable=True, location="json") parser.add_argument('icon', type=str, required=False, nullable=True, location='json')
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json')
args = parser.parse_args() args = parser.parse_args()
else: else:
args = {} args = {}
# convert to workflow mode # convert to workflow mode
workflow_service = WorkflowService() workflow_service = WorkflowService()
new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args) new_app_model = workflow_service.convert_to_workflow(
app_model=app_model,
account=current_user,
args=args
)
# return app id # return app id
return { return {
"new_app_id": new_app_model.id, 'new_app_id': new_app_model.id,
} }
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft") api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft')
api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import") api.add_resource(DraftWorkflowImportApi, '/apps/<uuid:app_id>/workflows/draft/import')
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run')
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run") api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run')
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop')
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run") api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run')
api.add_resource( api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run')
AdvancedChatDraftRunIterationNodeApi, api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run')
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run", api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish')
) api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs')
api.add_resource( api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs'
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run" '/<string:block_type>')
) api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow')
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
api.add_resource(
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>"
)
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
View File
@ -22,19 +22,20 @@ class WorkflowAppLogApi(Resource):
Get workflow app logs Get workflow app logs
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") parser.add_argument('keyword', type=str, location='args')
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args')
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument('page', type=int_range(1, 99999), default=1, location='args')
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") parser.add_argument('limit', type=int_range(1, 100), default=20, location='args')
args = parser.parse_args() args = parser.parse_args()
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
app_model=app_model, args=args app_model=app_model,
args=args
) )
return workflow_app_log_pagination return workflow_app_log_pagination
api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs") api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs')
View File
@ -28,12 +28,15 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
Get advanced chat app workflow run list Get advanced chat app workflow run list
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args() args = parser.parse_args()
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
app_model=app_model,
args=args
)
return result return result
@ -49,12 +52,15 @@ class WorkflowRunListApi(Resource):
Get workflow run list Get workflow run list
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args() args = parser.parse_args()
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) result = workflow_run_service.get_paginate_workflow_runs(
app_model=app_model,
args=args
)
return result return result
@ -92,10 +98,12 @@ class WorkflowRunNodeExecutionListApi(Resource):
workflow_run_service = WorkflowRunService() workflow_run_service = WorkflowRunService()
node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id)
return {"data": node_executions} return {
'data': node_executions
}
api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps/<uuid:app_id>/advanced-chat/workflow-runs") api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps/<uuid:app_id>/advanced-chat/workflow-runs')
api.add_resource(WorkflowRunListApi, "/apps/<uuid:app_id>/workflow-runs") api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs')
api.add_resource(WorkflowRunDetailApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>") api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>')
api.add_resource(WorkflowRunNodeExecutionListApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions") api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions')
View File
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import DatetimeString from libs.helper import datetime_string
from libs.login import login_required from libs.login import login_required
from models.model import AppMode from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom from models.workflow import WorkflowRunTriggeredFrom
@ -26,58 +26,56 @@ class WorkflowDailyRunsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = '''
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs
COUNT(id) AS runs FROM workflow_runs
FROM WHERE app_id = :app_id
workflow_runs AND triggered_from = :triggered_from
WHERE '''
app_id = :app_id arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start" sql_query += ' and created_at >= :start'
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += ' and created_at < :end'
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += ' GROUP BY date order by date'
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append({"date": str(i.date), "runs": i.runs}) response_data.append({
'date': str(i.date),
return jsonify({"data": response_data}) 'runs': i.runs
})
return jsonify({
'data': response_data
})
class WorkflowDailyTerminalsStatistic(Resource): class WorkflowDailyTerminalsStatistic(Resource):
@setup_required @setup_required
@ -88,58 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = '''
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count FROM workflow_runs
FROM WHERE app_id = :app_id
workflow_runs AND triggered_from = :triggered_from
WHERE '''
app_id = :app_id arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
AND triggered_from = :triggered_from"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start" sql_query += ' and created_at >= :start'
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += ' and created_at < :end'
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += ' GROUP BY date order by date'
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) response_data.append({
'date': str(i.date),
return jsonify({"data": response_data}) 'terminal_count': i.terminal_count
})
return jsonify({
'data': response_data
})
class WorkflowDailyTokenCostStatistic(Resource): class WorkflowDailyTokenCostStatistic(Resource):
@setup_required @setup_required
@ -150,63 +146,58 @@ class WorkflowDailyTokenCostStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = '''
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SELECT
SUM(workflow_runs.total_tokens) AS token_count date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
FROM SUM(workflow_runs.total_tokens) as token_count
workflow_runs FROM workflow_runs
WHERE WHERE app_id = :app_id
app_id = :app_id AND triggered_from = :triggered_from
AND triggered_from = :triggered_from""" '''
arg_dict = { arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at >= :start" sql_query += ' and created_at >= :start'
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query += " AND created_at < :end" sql_query += ' and created_at < :end'
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" sql_query += ' GROUP BY date order by date'
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append( response_data.append({
{ 'date': str(i.date),
"date": str(i.date), 'token_count': i.token_count,
"token_count": i.token_count, })
}
)
return jsonify({"data": response_data})
return jsonify({
'data': response_data
})
class WorkflowAverageAppInteractionStatistic(Resource): class WorkflowAverageAppInteractionStatistic(Resource):
@setup_required @setup_required
@ -217,79 +208,71 @@ class WorkflowAverageAppInteractionStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT sql_query = """
AVG(sub.interactions) AS interactions, SELECT
sub.date AVG(sub.interactions) as interactions,
FROM sub.date
( FROM
SELECT (SELECT
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
c.created_by, c.created_by,
COUNT(c.id) AS interactions COUNT(c.id) AS interactions
FROM FROM workflow_runs c
workflow_runs c WHERE c.app_id = :app_id
WHERE AND c.triggered_from = :triggered_from
c.app_id = :app_id {{start}}
AND c.triggered_from = :triggered_from {{end}}
{{start}} GROUP BY date, c.created_by) sub
{{end}} GROUP BY sub.date
GROUP BY """
date, c.created_by arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
) sub
GROUP BY
sub.date"""
arg_dict = {
"tz": account.timezone,
"app_id": app_model.id,
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
}
timezone = pytz.timezone(account.timezone) timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc utc_timezone = pytz.utc
if args["start"]: if args['start']:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M')
start_datetime = start_datetime.replace(second=0) start_datetime = start_datetime.replace(second=0)
start_datetime_timezone = timezone.localize(start_datetime) start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start") sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start')
arg_dict["start"] = start_datetime_utc arg_dict['start'] = start_datetime_utc
else: else:
sql_query = sql_query.replace("{{start}}", "") sql_query = sql_query.replace('{{start}}', '')
if args["end"]: if args['end']:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M')
end_datetime = end_datetime.replace(second=0) end_datetime = end_datetime.replace(second=0)
end_datetime_timezone = timezone.localize(end_datetime) end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end") sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end')
arg_dict["end"] = end_datetime_utc arg_dict['end'] = end_datetime_utc
else: else:
sql_query = sql_query.replace("{{end}}", "") sql_query = sql_query.replace('{{end}}', '')
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(sql_query), arg_dict) rs = conn.execute(db.text(sql_query), arg_dict)
for i in rs: for i in rs:
response_data.append( response_data.append({
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} 'date': str(i.date),
) 'interactions': float(i.interactions.quantize(Decimal('0.01')))
})
return jsonify({"data": response_data}) return jsonify({
'data': response_data
})
api.add_resource(WorkflowDailyRunsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-conversations')
api.add_resource(WorkflowDailyRunsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-conversations") api.add_resource(WorkflowDailyTerminalsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-terminals')
api.add_resource(WorkflowDailyTerminalsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-terminals") api.add_resource(WorkflowDailyTokenCostStatistic, '/apps/<uuid:app_id>/workflow/statistics/token-costs')
api.add_resource(WorkflowDailyTokenCostStatistic, "/apps/<uuid:app_id>/workflow/statistics/token-costs") api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps/<uuid:app_id>/workflow/statistics/average-app-interactions')
api.add_resource(
WorkflowAverageAppInteractionStatistic, "/apps/<uuid:app_id>/workflow/statistics/average-app-interactions"
)
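WorkflowAverageAppInteractionStatistic above cannot simply append "AND ..." to the end of its query, because the date filters belong inside the subquery; instead the template carries {{start}}/{{end}} placeholders that are either expanded into a clause or blanked out. A stand-alone sketch of that substitution (table and column names are illustrative):

sql_template = """
SELECT AVG(sub.interactions) AS interactions, sub.date
FROM (
    SELECT DATE(c.created_at) AS date, c.created_by, COUNT(c.id) AS interactions
    FROM workflow_runs c
    WHERE c.app_id = :app_id
    {{start}}
    {{end}}
    GROUP BY date, c.created_by
) sub
GROUP BY sub.date
"""

params = {"app_id": "app-123"}
start_utc = None            # pretend only an end bound was supplied
end_utc = "2024-08-19"

# Expand each placeholder into a clause when the bound exists, otherwise drop it.
sql = sql_template.replace("{{start}}", "AND c.created_at >= :start" if start_utc else "")
sql = sql.replace("{{end}}", "AND c.created_at < :end" if end_utc else "")
if start_utc:
    params["start"] = start_utc
if end_utc:
    params["end"] = end_utc

print(sql)
print(params)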
View File
@ -8,23 +8,24 @@ from libs.login import current_user
from models.model import App, AppMode from models.model import App, AppMode
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): def get_app_model(view: Optional[Callable] = None, *,
mode: Union[AppMode, list[AppMode]] = None):
def decorator(view_func): def decorator(view_func):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args, **kwargs): def decorated_view(*args, **kwargs):
if not kwargs.get("app_id"): if not kwargs.get('app_id'):
raise ValueError("missing app_id in path parameters") raise ValueError('missing app_id in path parameters')
app_id = kwargs.get("app_id") app_id = kwargs.get('app_id')
app_id = str(app_id) app_id = str(app_id)
del kwargs["app_id"] del kwargs['app_id']
app_model = ( app_model = db.session.query(App).filter(
db.session.query(App) App.id == app_id,
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") App.tenant_id == current_user.current_tenant_id,
.first() App.status == 'normal'
) ).first()
if not app_model: if not app_model:
raise AppNotFoundError() raise AppNotFoundError()
@ -43,10 +44,9 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
mode_values = {m.value for m in modes} mode_values = {m.value for m in modes}
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
kwargs["app_model"] = app_model kwargs['app_model'] = app_model
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
return decorated_view return decorated_view
if view is None: if view is None:
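get_app_model pops app_id out of the view kwargs, loads the tenant-scoped App row, optionally checks the app mode, and injects the result back as app_model. A stripped-down, framework-free sketch of that decorator shape; the lookup is stubbed and mode is a plain string here, whereas the real code queries the database and uses AppMode enums:

from functools import wraps

def get_app_model(view=None, *, mode=None):
    def decorator(view_func):
        @wraps(view_func)
        def decorated_view(*args, **kwargs):
            app_id = str(kwargs.pop("app_id"))      # path parameter -> plain string
            app_model = fake_load_app(app_id)       # stand-in for the db.session query
            if mode is not None:
                allowed = {mode} if isinstance(mode, str) else set(mode)
                if app_model["mode"] not in allowed:
                    raise ValueError("App mode is not in the supported list")
            kwargs["app_model"] = app_model         # the view receives app_model, not app_id
            return view_func(*args, **kwargs)
        return decorated_view
    # Support both @get_app_model and @get_app_model(mode=...)
    return decorator(view) if view is not None else decorator

def fake_load_app(app_id):
    return {"id": app_id, "mode": "workflow"}

@get_app_model(mode="workflow")
def show(app_model):
    return app_model["id"]

print(show(app_id="1234"))  # 1234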
View File
@ -8,7 +8,7 @@ from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import StrLen, email, timezone from libs.helper import email, str_len, timezone
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models.account import AccountStatus from models.account import AccountStatus
from services.account_service import RegisterService from services.account_service import RegisterService
@ -17,61 +17,60 @@ from services.account_service import RegisterService
class ActivateCheckApi(Resource): class ActivateCheckApi(Resource):
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args')
parser.add_argument("email", type=email, required=False, nullable=True, location="args") parser.add_argument('email', type=email, required=False, nullable=True, location='args')
parser.add_argument("token", type=str, required=True, nullable=False, location="args") parser.add_argument('token', type=str, required=True, nullable=False, location='args')
args = parser.parse_args() args = parser.parse_args()
workspaceId = args["workspace_id"] workspaceId = args['workspace_id']
reg_email = args["email"] reg_email = args['email']
token = args["token"] token = args['token']
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None} return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None}
class ActivateApi(Resource): class ActivateApi(Resource):
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json')
parser.add_argument("email", type=email, required=False, nullable=True, location="json") parser.add_argument('email', type=email, required=False, nullable=True, location='json')
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json')
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument( parser.add_argument('interface_language', type=supported_language, required=True, nullable=False,
"interface_language", type=supported_language, required=True, nullable=False, location="json" location='json')
) parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json')
parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token'])
if invitation is None: if invitation is None:
raise AlreadyActivateError() raise AlreadyActivateError()
RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) RegisterService.revoke_token(args['workspace_id'], args['email'], args['token'])
account = invitation["account"] account = invitation['account']
account.name = args["name"] account.name = args['name']
# generate password salt # generate password salt
salt = secrets.token_bytes(16) salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode() base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt # encrypt password with salt
password_hashed = hash_password(args["password"], salt) password_hashed = hash_password(args['password'], salt)
base64_password_hashed = base64.b64encode(password_hashed).decode() base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed account.password = base64_password_hashed
account.password_salt = base64_salt account.password_salt = base64_salt
account.interface_language = args["interface_language"] account.interface_language = args['interface_language']
account.timezone = args["timezone"] account.timezone = args['timezone']
account.interface_theme = "light" account.interface_theme = 'light'
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success"} return {'result': 'success'}
api.add_resource(ActivateCheckApi, "/activate/check") api.add_resource(ActivateCheckApi, '/activate/check')
api.add_resource(ActivateApi, "/activate") api.add_resource(ActivateApi, '/activate')
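Note: a minimal client-side sketch of the two activation endpoints above, assuming the console blueprint is mounted under /console/api (as the OAuth redirect URIs elsewhere in this PR suggest); the base URL, workspace id, token and account values are placeholders taken from an invitation email, not part of this change.

import requests

BASE = "https://cloud.example.com/console/api"  # placeholder console API base URL
WORKSPACE_ID, EMAIL, TOKEN = "<workspace-uuid>", "invitee@example.com", "<invite-token>"  # placeholders

# 1. Check whether the invitation token is still valid.
check = requests.get(
    f"{BASE}/activate/check",
    params={"workspace_id": WORKSPACE_ID, "email": EMAIL, "token": TOKEN},
).json()
print(check)  # {"is_valid": true, "workspace_name": "..."} while the token is valid

# 2. Activate the account; the JSON fields mirror the reqparse arguments above.
if check.get("is_valid"):
    resp = requests.post(
        f"{BASE}/activate",
        json={
            "workspace_id": WORKSPACE_ID,
            "email": EMAIL,
            "token": TOKEN,
            "name": "New Member",            # at most 30 characters (StrLen(30))
            "password": "a-Valid-Passw0rd",  # placeholder; must satisfy valid_password
            "interface_language": "en-US",
            "timezone": "Asia/Shanghai",
        },
    )
    print(resp.json())  # {"result": "success"}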

View File

@ -19,19 +19,18 @@ class ApiKeyAuthDataSource(Resource):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
if data_source_api_key_bindings: if data_source_api_key_bindings:
return { return {
"sources": [ 'sources': [{
{ 'id': data_source_api_key_binding.id,
"id": data_source_api_key_binding.id, 'category': data_source_api_key_binding.category,
"category": data_source_api_key_binding.category, 'provider': data_source_api_key_binding.provider,
"provider": data_source_api_key_binding.provider, 'disabled': data_source_api_key_binding.disabled,
"disabled": data_source_api_key_binding.disabled, 'created_at': int(data_source_api_key_binding.created_at.timestamp()),
"created_at": int(data_source_api_key_binding.created_at.timestamp()), 'updated_at': int(data_source_api_key_binding.updated_at.timestamp()),
"updated_at": int(data_source_api_key_binding.updated_at.timestamp()), }
} for data_source_api_key_binding in
for data_source_api_key_binding in data_source_api_key_bindings data_source_api_key_bindings]
]
} }
return {"sources": []} return {'sources': []}
class ApiKeyAuthDataSourceBinding(Resource): class ApiKeyAuthDataSourceBinding(Resource):
@ -43,16 +42,16 @@ class ApiKeyAuthDataSourceBinding(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("category", type=str, required=True, nullable=False, location="json") parser.add_argument('category', type=str, required=True, nullable=False, location='json')
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
ApiKeyAuthService.validate_api_key_auth_args(args) ApiKeyAuthService.validate_api_key_auth_args(args)
try: try:
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
except Exception as e: except Exception as e:
raise ApiKeyAuthFailedError(str(e)) raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200 return {'result': 'success'}, 200
class ApiKeyAuthDataSourceBindingDelete(Resource): class ApiKeyAuthDataSourceBindingDelete(Resource):
@ -66,9 +65,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
return {"result": "success"}, 200 return {'result': 'success'}, 200
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source") api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding") api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>") api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
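Note: a short sketch of driving the three api-key-auth resources above from a console session. The base URL, bearer token, category, provider and credentials are placeholders; valid values are defined by ApiKeyAuthService, and the binding endpoints require an admin or owner role per the Forbidden check above.

import requests

BASE = "https://cloud.example.com/console/api"         # placeholder
HEADERS = {"Authorization": "Bearer <console-token>"}   # placeholder console session token

# Create a binding for an API-key based data source provider.
requests.post(
    f"{BASE}/api-key-auth/data-source/binding",
    headers=HEADERS,
    json={"category": "<category>", "provider": "<provider>", "credentials": {"api_key": "<key>"}},
)

# List the workspace's existing bindings.
sources = requests.get(f"{BASE}/api-key-auth/data-source", headers=HEADERS).json()["sources"]

# Delete a binding by id.
if sources:
    requests.delete(f"{BASE}/api-key-auth/data-source/{sources[0]['id']}", headers=HEADERS)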

View File

@ -17,13 +17,13 @@ from ..wraps import account_initialization_required
def get_oauth_providers(): def get_oauth_providers():
with current_app.app_context(): with current_app.app_context():
notion_oauth = NotionOAuth( notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID,
client_id=dify_config.NOTION_CLIENT_ID, client_secret=dify_config.NOTION_CLIENT_SECRET,
client_secret=dify_config.NOTION_CLIENT_SECRET, redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion')
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion",
)
OAUTH_PROVIDERS = {"notion": notion_oauth} OAUTH_PROVIDERS = {
'notion': notion_oauth
}
return OAUTH_PROVIDERS return OAUTH_PROVIDERS
@ -37,16 +37,18 @@ class OAuthDataSource(Resource):
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
print(vars(oauth_provider)) print(vars(oauth_provider))
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {'error': 'Invalid provider'}, 400
if dify_config.NOTION_INTEGRATION_TYPE == "internal": if dify_config.NOTION_INTEGRATION_TYPE == 'internal':
internal_secret = dify_config.NOTION_INTERNAL_SECRET internal_secret = dify_config.NOTION_INTERNAL_SECRET
if not internal_secret: if not internal_secret:
return ({"error": "Internal secret is not set"},) return {'error': 'Internal secret is not set'},
oauth_provider.save_internal_access_token(internal_secret) oauth_provider.save_internal_access_token(internal_secret)
return {"data": ""} return { 'data': '' }
else: else:
auth_url = oauth_provider.get_authorization_url() auth_url = oauth_provider.get_authorization_url()
return {"data": auth_url}, 200 return { 'data': auth_url }, 200
class OAuthDataSourceCallback(Resource): class OAuthDataSourceCallback(Resource):
@ -55,18 +57,18 @@ class OAuthDataSourceCallback(Resource):
with current_app.app_context(): with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {'error': 'Invalid provider'}, 400
if "code" in request.args: if 'code' in request.args:
code = request.args.get("code") code = request.args.get('code')
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}") return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}')
elif "error" in request.args: elif 'error' in request.args:
error = request.args.get("error") error = request.args.get('error')
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}") return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}')
else: else:
return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied')
class OAuthDataSourceBinding(Resource): class OAuthDataSourceBinding(Resource):
def get(self, provider: str): def get(self, provider: str):
@ -74,18 +76,17 @@ class OAuthDataSourceBinding(Resource):
with current_app.app_context(): with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {'error': 'Invalid provider'}, 400
if "code" in request.args: if 'code' in request.args:
code = request.args.get("code") code = request.args.get('code')
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:
logging.exception( logging.exception(
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}" f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
) return {'error': 'OAuth data source process failed'}, 400
return {"error": "OAuth data source process failed"}, 400
return {"result": "success"}, 200 return {'result': 'success'}, 200
class OAuthDataSourceSync(Resource): class OAuthDataSourceSync(Resource):
@ -99,17 +100,18 @@ class OAuthDataSourceSync(Resource):
with current_app.app_context(): with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {'error': 'Invalid provider'}, 400
try: try:
oauth_provider.sync_data_source(binding_id) oauth_provider.sync_data_source(binding_id)
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:
logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") logging.exception(
return {"error": "OAuth data source process failed"}, 400 f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
return {'error': 'OAuth data source process failed'}, 400
return {"result": "success"}, 200 return {'result': 'success'}, 200
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>") api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>") api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>") api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync") api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')
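Note: the Notion data-source flow above has an internal-integration shortcut (it stores NOTION_INTERNAL_SECRET and returns {"data": ""}) and a public OAuth path. A hedged sketch of the public path, with placeholder URL, token, code and binding id, and assuming the sync resource is invoked with a plain GET:

import requests

BASE = "https://cloud.example.com/console/api"         # placeholder
HEADERS = {"Authorization": "Bearer <console-token>"}   # placeholder

# 1. Ask the console for Notion's authorization URL.
auth_url = requests.get(f"{BASE}/oauth/data-source/notion", headers=HEADERS).json()["data"]
print("open in a browser:", auth_url)

# 2. Notion redirects to /oauth/data-source/callback/notion?code=..., which bounces the
#    browser to the web console with ?type=notion&code=<code>; the console then binds it:
requests.get(
    f"{BASE}/oauth/data-source/binding/notion",
    headers=HEADERS,
    params={"code": "<code-from-redirect>"},  # placeholder
)

# 3. Re-sync all pages of a bound workspace later on.
requests.get(f"{BASE}/oauth/data-source/notion/<binding-uuid>/sync", headers=HEADERS)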

View File

@ -2,30 +2,31 @@ from libs.exception import BaseHTTPException
class ApiKeyAuthFailedError(BaseHTTPException): class ApiKeyAuthFailedError(BaseHTTPException):
error_code = "auth_failed" error_code = 'auth_failed'
description = "{message}" description = "{message}"
code = 500 code = 500
class InvalidEmailError(BaseHTTPException): class InvalidEmailError(BaseHTTPException):
error_code = "invalid_email" error_code = 'invalid_email'
description = "The email address is not valid." description = "The email address is not valid."
code = 400 code = 400
class PasswordMismatchError(BaseHTTPException): class PasswordMismatchError(BaseHTTPException):
error_code = "password_mismatch" error_code = 'password_mismatch'
description = "The passwords do not match." description = "The passwords do not match."
code = 400 code = 400
class InvalidTokenError(BaseHTTPException): class InvalidTokenError(BaseHTTPException):
error_code = "invalid_or_expired_token" error_code = 'invalid_or_expired_token'
description = "The token is invalid or has expired." description = "The token is invalid or has expired."
code = 400 code = 400
class PasswordResetRateLimitExceededError(BaseHTTPException): class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = "password_reset_rate_limit_exceeded" error_code = 'password_reset_rate_limit_exceeded'
description = "Password reset rate limit exceeded. Try again later." description = "Password reset rate limit exceeded. Try again later."
code = 429 code = 429
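Note: every class in this module follows the same three-attribute pattern on BaseHTTPException (imported from libs.exception per the hunk header): a machine-readable error_code, a human-readable description, and an HTTP status. A hypothetical additional error, purely for illustration and not part of this PR, would look like:

from libs.exception import BaseHTTPException


class EmailAlreadyInUseError(BaseHTTPException):  # hypothetical example
    error_code = "email_already_in_use"
    description = "An account with this email address already exists."
    code = 409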

View File

@ -21,13 +21,14 @@ from services.errors.account import RateLimitExceededError
class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordSendEmailApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json") parser.add_argument('email', type=str, required=True, location='json')
args = parser.parse_args() args = parser.parse_args()
email = args["email"] email = args['email']
if not email_validate(email): if not email_validate(email):
raise InvalidEmailError() raise InvalidEmailError()
@ -48,36 +49,38 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource): class ForgotPasswordCheckApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument('token', type=str, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
token = args["token"] token = args['token']
reset_data = AccountService.get_reset_password_data(token) reset_data = AccountService.get_reset_password_data(token)
if reset_data is None: if reset_data is None:
return {"is_valid": False, "email": None} return {'is_valid': False, 'email': None}
return {"is_valid": True, "email": reset_data.get("email")} return {'is_valid': True, 'email': reset_data.get('email')}
class ForgotPasswordResetApi(Resource): class ForgotPasswordResetApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument('token', type=str, required=True, nullable=False, location='json')
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
args = parser.parse_args() args = parser.parse_args()
new_password = args["new_password"] new_password = args['new_password']
password_confirm = args["password_confirm"] password_confirm = args['password_confirm']
if str(new_password).strip() != str(password_confirm).strip(): if str(new_password).strip() != str(password_confirm).strip():
raise PasswordMismatchError() raise PasswordMismatchError()
token = args["token"] token = args['token']
reset_data = AccountService.get_reset_password_data(token) reset_data = AccountService.get_reset_password_data(token)
if reset_data is None: if reset_data is None:
@ -91,14 +94,14 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(new_password, salt) password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode() base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get("email")).first() account = Account.query.filter_by(email=reset_data.get('email')).first()
account.password = base64_password_hashed account.password = base64_password_hashed
account.password_salt = base64_salt account.password_salt = base64_salt
db.session.commit() db.session.commit()
return {"result": "success"} return {'result': 'success'}
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
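Note: a minimal sketch of the full reset flow across the three endpoints above; the base URL is a placeholder, and the token itself arrives by email, so the flow cannot be scripted fully end to end.

import requests

BASE = "https://cloud.example.com/console/api"  # placeholder

# 1. Request a reset email.
requests.post(f"{BASE}/forgot-password", json={"email": "user@example.com"})

# 2. Validate the token the user pasted from the email.
token = "<token-from-email>"  # placeholder
check = requests.post(f"{BASE}/forgot-password/validity", json={"token": token}).json()

# 3. Set the new password; both fields must match after stripping whitespace.
if check["is_valid"]:
    requests.post(
        f"{BASE}/forgot-password/resets",
        json={"token": token, "new_password": "N3w-Passw0rd", "password_confirm": "N3w-Passw0rd"},
    )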

View File

@ -20,39 +20,37 @@ class LoginApi(Resource):
def post(self): def post(self):
"""Authenticate user and login.""" """Authenticate user and login."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json") parser.add_argument('email', type=email, required=True, location='json')
parser.add_argument("password", type=valid_password, required=True, location="json") parser.add_argument('password', type=valid_password, required=True, location='json')
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") parser.add_argument('remember_me', type=bool, required=False, default=False, location='json')
args = parser.parse_args() args = parser.parse_args()
# todo: Verify the recaptcha # todo: Verify the recaptcha
try: try:
account = AccountService.authenticate(args["email"], args["password"]) account = AccountService.authenticate(args['email'], args['password'])
except services.errors.account.AccountLoginError as e: except services.errors.account.AccountLoginError as e:
return {"code": "unauthorized", "message": str(e)}, 401 return {'code': 'unauthorized', 'message': str(e)}, 401
# SELF_HOSTED only have one workspace # SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0: if len(tenants) == 0:
return { return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
"result": "fail",
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
}
token = AccountService.login(account, ip_address=get_remote_ip(request)) token = AccountService.login(account, ip_address=get_remote_ip(request))
return {"result": "success", "data": token} return {'result': 'success', 'data': token}
class LogoutApi(Resource): class LogoutApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
account = cast(Account, flask_login.current_user) account = cast(Account, flask_login.current_user)
token = request.headers.get("Authorization", "").split(" ")[1] token = request.headers.get('Authorization', '').split(' ')[1]
AccountService.logout(account=account, token=token) AccountService.logout(account=account, token=token)
flask_login.logout_user() flask_login.logout_user()
return {"result": "success"} return {'result': 'success'}
class ResetPasswordApi(Resource): class ResetPasswordApi(Resource):
@ -82,11 +80,11 @@ class ResetPasswordApi(Resource):
# 'subject': 'Reset your Dify password', # 'subject': 'Reset your Dify password',
# 'html': """ # 'html': """
# <p>Dear User,</p> # <p>Dear User,</p>
# <p>The Dify team has generated a new password for you, details as follows:</p> # <p>The Dify team has generated a new password for you, details as follows:</p>
# <p><strong>{new_password}</strong></p> # <p><strong>{new_password}</strong></p>
# <p>Please change your password to log in as soon as possible.</p> # <p>Please change your password to log in as soon as possible.</p>
# <p>Regards,</p> # <p>Regards,</p>
# <p>The Dify Team</p> # <p>The Dify Team</p>
# """ # """
# } # }
@ -103,8 +101,8 @@ class ResetPasswordApi(Resource):
# # handle error # # handle error
# pass # pass
return {"result": "success"} return {'result': 'success'}
api.add_resource(LoginApi, "/login") api.add_resource(LoginApi, '/login')
api.add_resource(LogoutApi, "/logout") api.add_resource(LogoutApi, '/logout')
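Note: LoginApi returns the session token under data, and LogoutApi reads it back from the second part of the Authorization header, so a bearer-style header works for console calls. A sketch with placeholder base URL and credentials:

import requests

BASE = "https://cloud.example.com/console/api"  # placeholder

resp = requests.post(
    f"{BASE}/login",
    json={"email": "user@example.com", "password": "a-Valid-Passw0rd", "remember_me": True},
).json()
assert resp["result"] == "success", resp
token = resp["data"]

# Any subsequent console request, including logout, sends the token back as a bearer token.
requests.get(f"{BASE}/logout", headers={"Authorization": f"Bearer {token}"})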

View File

@ -25,7 +25,7 @@ def get_oauth_providers():
github_oauth = GitHubOAuth( github_oauth = GitHubOAuth(
client_id=dify_config.GITHUB_CLIENT_ID, client_id=dify_config.GITHUB_CLIENT_ID,
client_secret=dify_config.GITHUB_CLIENT_SECRET, client_secret=dify_config.GITHUB_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github", redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github',
) )
if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET: if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET:
google_oauth = None google_oauth = None
@ -33,10 +33,10 @@ def get_oauth_providers():
google_oauth = GoogleOAuth( google_oauth = GoogleOAuth(
client_id=dify_config.GOOGLE_CLIENT_ID, client_id=dify_config.GOOGLE_CLIENT_ID,
client_secret=dify_config.GOOGLE_CLIENT_SECRET, client_secret=dify_config.GOOGLE_CLIENT_SECRET,
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google',
) )
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth}
return OAUTH_PROVIDERS return OAUTH_PROVIDERS
@ -47,7 +47,7 @@ class OAuthLogin(Resource):
oauth_provider = OAUTH_PROVIDERS.get(provider) oauth_provider = OAUTH_PROVIDERS.get(provider)
print(vars(oauth_provider)) print(vars(oauth_provider))
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {'error': 'Invalid provider'}, 400
auth_url = oauth_provider.get_authorization_url() auth_url = oauth_provider.get_authorization_url()
return redirect(auth_url) return redirect(auth_url)
@ -59,20 +59,20 @@ class OAuthCallback(Resource):
with current_app.app_context(): with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider) oauth_provider = OAUTH_PROVIDERS.get(provider)
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {'error': 'Invalid provider'}, 400
code = request.args.get("code") code = request.args.get('code')
try: try:
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}')
return {"error": "OAuth process failed"}, 400 return {'error': 'OAuth process failed'}, 400
account = _generate_account(provider, user_info) account = _generate_account(provider, user_info)
# Check account status # Check account status
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
return {"error": "Account is banned or closed."}, 403 return {'error': 'Account is banned or closed.'}, 403
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
@ -83,7 +83,7 @@ class OAuthCallback(Resource):
token = AccountService.login(account, ip_address=get_remote_ip(request)) token = AccountService.login(account, ip_address=get_remote_ip(request))
return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}") return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}')
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
if not account: if not account:
# Create account # Create account
account_name = user_info.name or "Dify" account_name = user_info.name if user_info.name else 'Dify'
account = RegisterService.register( account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
) )
@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
return account return account
api.add_resource(OAuthLogin, "/oauth/login/<provider>") api.add_resource(OAuthLogin, '/oauth/login/<provider>')
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>") api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')
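Note: OAuthLogin answers with a redirect to the provider rather than a JSON body, so a non-browser client has to read the Location header. A sketch for the GitHub provider with a placeholder base URL:

import requests

BASE = "https://cloud.example.com/console/api"  # placeholder

# Start the flow; the console replies with a redirect to GitHub's authorize page.
start = requests.get(f"{BASE}/oauth/login/github", allow_redirects=False)
print(start.status_code, start.headers.get("Location"))

# After the user authorizes, GitHub calls /oauth/authorize/github?code=..., and the
# console redirects the browser to CONSOLE_WEB_URL?console_token=<token>, which the
# web app stores as its session token.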

View File

@ -9,24 +9,28 @@ from services.billing_service import BillingService
class Subscription(Resource): class Subscription(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@only_edition_cloud @only_edition_cloud
def get(self): def get(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team'])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year'])
args = parser.parse_args() args = parser.parse_args()
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription( return BillingService.get_subscription(args['plan'],
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id args['interval'],
) current_user.email,
current_user.current_tenant_id)
class Invoices(Resource): class Invoices(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -36,5 +40,5 @@ class Invoices(Resource):
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
api.add_resource(Subscription, "/billing/subscription") api.add_resource(Subscription, '/billing/subscription')
api.add_resource(Invoices, "/billing/invoices") api.add_resource(Invoices, '/billing/invoices')
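Note: both billing resources are cloud-edition only and restricted to workspace owners/admins via BillingService.is_tenant_owner_or_admin. A sketch with placeholder base URL and token; per the choices above, plan must be professional or team and interval month or year.

import requests

BASE = "https://cloud.example.com/console/api"         # placeholder
HEADERS = {"Authorization": "Bearer <console-token>"}   # placeholder

subscription = requests.get(
    f"{BASE}/billing/subscription",
    headers=HEADERS,
    params={"plan": "professional", "interval": "month"},
).json()

invoices = requests.get(f"{BASE}/billing/invoices", headers=HEADERS).json()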

View File

@ -22,22 +22,19 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task
class DataSourceApi(Resource): class DataSourceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(integrate_list_fields) @marshal_with(integrate_list_fields)
def get(self): def get(self):
# get workspace data source integrates # get workspace data source integrates
data_source_integrates = ( data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
db.session.query(DataSourceOauthBinding) DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
.filter( DataSourceOauthBinding.disabled == False
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, ).all()
DataSourceOauthBinding.disabled == False,
)
.all()
)
base_url = request.url_root.rstrip("/") base_url = request.url_root.rstrip('/')
data_source_oauth_base_path = "/console/api/oauth/data-source" data_source_oauth_base_path = "/console/api/oauth/data-source"
providers = ["notion"] providers = ["notion"]
@ -47,30 +44,26 @@ class DataSourceApi(Resource):
existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates) existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates)
if existing_integrates: if existing_integrates:
for existing_integrate in list(existing_integrates): for existing_integrate in list(existing_integrates):
integrate_data.append( integrate_data.append({
{ 'id': existing_integrate.id,
"id": existing_integrate.id, 'provider': provider,
"provider": provider, 'created_at': existing_integrate.created_at,
"created_at": existing_integrate.created_at, 'is_bound': True,
"is_bound": True, 'disabled': existing_integrate.disabled,
"disabled": existing_integrate.disabled, 'source_info': existing_integrate.source_info,
"source_info": existing_integrate.source_info, 'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
"link": f"{base_url}{data_source_oauth_base_path}/{provider}", })
}
)
else: else:
integrate_data.append( integrate_data.append({
{ 'id': None,
"id": None, 'provider': provider,
"provider": provider, 'created_at': None,
"created_at": None, 'source_info': None,
"source_info": None, 'is_bound': False,
"is_bound": False, 'disabled': None,
"disabled": None, 'link': f'{base_url}{data_source_oauth_base_path}/{provider}'
"link": f"{base_url}{data_source_oauth_base_path}/{provider}", })
} return {'data': integrate_data}, 200
)
return {"data": integrate_data}, 200
@setup_required @setup_required
@login_required @login_required
@ -78,82 +71,92 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action): def patch(self, binding_id, action):
binding_id = str(binding_id) binding_id = str(binding_id)
action = str(action) action = str(action)
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() data_source_binding = DataSourceOauthBinding.query.filter_by(
id=binding_id
).first()
if data_source_binding is None: if data_source_binding is None:
raise NotFound("Data source binding not found.") raise NotFound('Data source binding not found.')
# enable binding # enable binding
if action == "enable": if action == 'enable':
if data_source_binding.disabled: if data_source_binding.disabled:
data_source_binding.disabled = False data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(data_source_binding) db.session.add(data_source_binding)
db.session.commit() db.session.commit()
else: else:
raise ValueError("Data source is not disabled.") raise ValueError('Data source is not disabled.')
# disable binding # disable binding
if action == "disable": if action == 'disable':
if not data_source_binding.disabled: if not data_source_binding.disabled:
data_source_binding.disabled = True data_source_binding.disabled = True
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(data_source_binding) db.session.add(data_source_binding)
db.session.commit() db.session.commit()
else: else:
raise ValueError("Data source is disabled.") raise ValueError('Data source is disabled.')
return {"result": "success"}, 200 return {'result': 'success'}, 200
class DataSourceNotionListApi(Resource): class DataSourceNotionListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(integrate_notion_info_list_fields) @marshal_with(integrate_notion_info_list_fields)
def get(self): def get(self):
dataset_id = request.args.get("dataset_id", default=None, type=str) dataset_id = request.args.get('dataset_id', default=None, type=str)
exist_page_ids = [] exist_page_ids = []
# import notion in the exist dataset # import notion in the exist dataset
if dataset_id: if dataset_id:
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound('Dataset not found.')
if dataset.data_source_type != "notion_import": if dataset.data_source_type != 'notion_import':
raise ValueError("Dataset is not notion type.") raise ValueError('Dataset is not notion type.')
documents = Document.query.filter_by( documents = Document.query.filter_by(
dataset_id=dataset_id, dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
data_source_type="notion_import", data_source_type='notion_import',
enabled=True, enabled=True
).all() ).all()
if documents: if documents:
for document in documents: for document in documents:
data_source_info = json.loads(document.data_source_info) data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"]) exist_page_ids.append(data_source_info['notion_page_id'])
# get all authorized pages # get all authorized pages
data_source_bindings = DataSourceOauthBinding.query.filter_by( data_source_bindings = DataSourceOauthBinding.query.filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False tenant_id=current_user.current_tenant_id,
provider='notion',
disabled=False
).all() ).all()
if not data_source_bindings: if not data_source_bindings:
return {"notion_info": []}, 200 return {
'notion_info': []
}, 200
pre_import_info_list = [] pre_import_info_list = []
for data_source_binding in data_source_bindings: for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info source_info = data_source_binding.source_info
pages = source_info["pages"] pages = source_info['pages']
# Filter out already bound pages # Filter out already bound pages
for page in pages: for page in pages:
if page["page_id"] in exist_page_ids: if page['page_id'] in exist_page_ids:
page["is_bound"] = True page['is_bound'] = True
else: else:
page["is_bound"] = False page['is_bound'] = False
pre_import_info = { pre_import_info = {
"workspace_name": source_info["workspace_name"], 'workspace_name': source_info['workspace_name'],
"workspace_icon": source_info["workspace_icon"], 'workspace_icon': source_info['workspace_icon'],
"workspace_id": source_info["workspace_id"], 'workspace_id': source_info['workspace_id'],
"pages": pages, 'pages': pages,
} }
pre_import_info_list.append(pre_import_info) pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200 return {
'notion_info': pre_import_info_list
}, 200
class DataSourceNotionApi(Resource): class DataSourceNotionApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -163,67 +166,64 @@ class DataSourceNotionApi(Resource):
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = DataSourceOauthBinding.query.filter(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == 'notion',
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
) )
).first() ).first()
if not data_source_binding: if not data_source_binding:
raise NotFound("Data source binding not found.") raise NotFound('Data source binding not found.')
extractor = NotionExtractor( extractor = NotionExtractor(
notion_workspace_id=workspace_id, notion_workspace_id=workspace_id,
notion_obj_id=page_id, notion_obj_id=page_id,
notion_page_type=page_type, notion_page_type=page_type,
notion_access_token=data_source_binding.access_token, notion_access_token=data_source_binding.access_token,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id
) )
text_docs = extractor.extract() text_docs = extractor.extract()
return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200 return {
'content': "\n".join([doc.page_content for doc in text_docs])
}, 200
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json')
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument( parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)
notion_info_list = args["notion_info_list"] notion_info_list = args['notion_info_list']
extract_settings = [] extract_settings = []
for notion_info in notion_info_list: for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"] workspace_id = notion_info['workspace_id']
for page in notion_info["pages"]: for page in notion_info['pages']:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="notion_import", datasource_type="notion_import",
notion_info={ notion_info={
"notion_workspace_id": workspace_id, "notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"], "notion_obj_id": page['page_id'],
"notion_page_type": page["type"], "notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id, "tenant_id": current_user.current_tenant_id
}, },
document_model=args["doc_form"], document_model=args['doc_form']
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
current_user.current_tenant_id, args['process_rule'], args['doc_form'],
extract_settings, args['doc_language'])
args["process_rule"],
args["doc_form"],
args["doc_language"],
)
return response, 200 return response, 200
class DataSourceNotionDatasetSyncApi(Resource): class DataSourceNotionDatasetSyncApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -240,6 +240,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
class DataSourceNotionDocumentSyncApi(Resource): class DataSourceNotionDocumentSyncApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -257,14 +258,10 @@ class DataSourceNotionDocumentSyncApi(Resource):
return 200 return 200
api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>") api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>')
api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages") api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages')
api.add_resource( api.add_resource(DataSourceNotionApi,
DataSourceNotionApi, '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview',
"/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview", '/datasets/notion-indexing-estimate')
"/datasets/notion-indexing-estimate", api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync')
) api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync')
api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync")
api.add_resource(
DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync"
)

View File

@ -18,53 +18,58 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import related_app_list from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields from fields.document_fields import document_status_fields
from libs.login import login_required from libs.login import login_required
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from models.model import ApiToken, UploadFile from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _validate_name(name): def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40: if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.") raise ValueError('Name must be between 1 to 40 characters.')
return name return name
def _validate_description_length(description): def _validate_description_length(description):
if len(description) > 400: if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.") raise ValueError('Description cannot exceed 400 characters.')
return description return description
class DatasetListApi(Resource): class DatasetListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
page = request.args.get("page", default=1, type=int) page = request.args.get('page', default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get('limit', default=20, type=int)
ids = request.args.getlist("ids") ids = request.args.getlist('ids')
provider = request.args.get("provider", default="vendor") provider = request.args.get('provider', default="vendor")
search = request.args.get("keyword", default=None, type=str) search = request.args.get('keyword', default=None, type=str)
tag_ids = request.args.getlist("tag_ids") tag_ids = request.args.getlist('tag_ids')
if ids: if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
else: else:
datasets, total = DatasetService.get_datasets( datasets, total = DatasetService.get_datasets(page, limit, provider,
page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids current_user.current_tenant_id, current_user, search, tag_ids)
)
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
model_names = [] model_names = []
for embedding_model in embedding_models: for embedding_model in embedding_models:
@ -72,22 +77,28 @@ class DatasetListApi(Resource):
data = marshal(datasets, dataset_detail_fields) data = marshal(datasets, dataset_detail_fields)
for item in data: for item in data:
if item["indexing_technique"] == "high_quality": if item['indexing_technique'] == 'high_quality':
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names: if item_model in model_names:
item["embedding_available"] = True item['embedding_available'] = True
else: else:
item["embedding_available"] = False item['embedding_available'] = False
else: else:
item["embedding_available"] = True item['embedding_available'] = True
if item.get("permission") == "partial_members": if item.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"]) part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id'])
item.update({"partial_member_list": part_users_list}) item.update({'partial_member_list': part_users_list})
else: else:
item.update({"partial_member_list": []}) item.update({'partial_member_list': []})
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} response = {
'data': data,
'has_more': len(datasets) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200 return response, 200
@setup_required @setup_required
@ -95,21 +106,13 @@ class DatasetListApi(Resource):
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument('name', nullable=False, required=True,
"name", help='type is required. Name must be between 1 to 40 characters.',
nullable=False, type=_validate_name)
required=True, parser.add_argument('indexing_technique', type=str, location='json',
help="type is required. Name must be between 1 to 40 characters.", choices=Dataset.INDEXING_TECHNIQUE_LIST,
type=_validate_name, nullable=True,
) help='Invalid indexing technique.')
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
args = parser.parse_args() args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@ -119,10 +122,9 @@ class DatasetListApi(Resource):
try: try:
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
name=args["name"], name=args['name'],
indexing_technique=args["indexing_technique"], indexing_technique=args['indexing_technique'],
account=current_user, account=current_user
permission=DatasetPermissionEnum.ONLY_ME,
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@ -140,36 +142,42 @@ class DatasetApi(Resource):
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
try: try:
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(
dataset, current_user)
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields) data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members": if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list}) data.update({'partial_member_list': part_users_list})
# check embedding setting # check embedding setting
provider_manager = ProviderManager() provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) configurations = provider_manager.get_configurations(
tenant_id=current_user.current_tenant_id
)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) embedding_models = configurations.get_models(
model_type=ModelType.TEXT_EMBEDDING,
only_active=True
)
model_names = [] model_names = []
for embedding_model in embedding_models: for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality": if data['indexing_technique'] == 'high_quality':
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names: if item_model in model_names:
data["embedding_available"] = True data['embedding_available'] = True
else: else:
data["embedding_available"] = False data['embedding_available'] = False
else: else:
data["embedding_available"] = True data['embedding_available'] = True
if data.get("permission") == "partial_members": if data.get('permission') == 'partial_members':
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list}) data.update({'partial_member_list': part_users_list})
return data, 200 return data, 200
@ -183,49 +191,42 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument('name', nullable=False,
"name", help='type is required. Name must be between 1 to 40 characters.',
nullable=False, type=_validate_name)
help="type is required. Name must be between 1 to 40 characters.", parser.add_argument('description',
type=_validate_name, location='json', store_missing=False,
) type=_validate_description_length)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) parser.add_argument('indexing_technique', type=str, location='json',
parser.add_argument( choices=Dataset.INDEXING_TECHNIQUE_LIST,
"indexing_technique", nullable=True,
type=str, help='Invalid indexing technique.')
location="json", parser.add_argument('permission', type=str, location='json', choices=(
choices=Dataset.INDEXING_TECHNIQUE_LIST, 'only_me', 'all_team_members', 'partial_members'), help='Invalid permission.'
nullable=True, )
help="Invalid indexing technique.", parser.add_argument('embedding_model', type=str,
) location='json', help='Invalid embedding model.')
parser.add_argument( parser.add_argument('embedding_model_provider', type=str,
"permission", location='json', help='Invalid embedding model provider.')
type=str, parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
location="json", parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.')
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
args = parser.parse_args() args = parser.parse_args()
data = request.get_json() data = request.get_json()
# check embedding model setting # check embedding model setting
if data.get("indexing_technique") == "high_quality": if data.get('indexing_technique') == 'high_quality':
DatasetService.check_embedding_model_setting( DatasetService.check_embedding_model_setting(dataset.tenant_id,
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") data.get('embedding_model_provider'),
) data.get('embedding_model')
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission( DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list") current_user, dataset, data.get('permission'), data.get('partial_member_list')
) )
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) dataset = DatasetService.update_dataset(
dataset_id_str, args, current_user)
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -233,19 +234,16 @@ class DatasetApi(Resource):
result_data = marshal(dataset, dataset_detail_fields) result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members": if data.get('partial_member_list') and data.get('permission') == 'partial_members':
DatasetPermissionService.update_partial_member_list( DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list") tenant_id, dataset_id_str, data.get('partial_member_list')
) )
# clear partial member list when permission is only_me or all_team_members # clear partial member list when permission is only_me or all_team_members
elif ( elif data.get('permission') == 'only_me' or data.get('permission') == 'all_team_members':
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str) DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list}) result_data.update({'partial_member_list': partial_member_list})
return result_data, 200 return result_data, 200
@ -262,13 +260,12 @@ class DatasetApi(Resource):
try: try:
if DatasetService.delete_dataset(dataset_id_str, current_user): if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str) DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204 return {'result': 'success'}, 204
else: else:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError: except services.errors.dataset.DatasetInUseError:
raise DatasetInUseError() raise DatasetInUseError()
class DatasetUseCheckApi(Resource): class DatasetUseCheckApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -277,10 +274,10 @@ class DatasetUseCheckApi(Resource):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str) dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
return {"is_using": dataset_is_using}, 200 return {'is_using': dataset_is_using}, 200
class DatasetQueryApi(Resource): class DatasetQueryApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -295,53 +292,51 @@ class DatasetQueryApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
page = request.args.get("page", default=1, type=int) page = request.args.get('page', default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get('limit', default=20, type=int)
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit) dataset_queries, total = DatasetService.get_dataset_queries(
dataset_id=dataset.id,
page=page,
per_page=limit
)
response = { response = {
"data": marshal(dataset_queries, dataset_query_detail_fields), 'data': marshal(dataset_queries, dataset_query_detail_fields),
"has_more": len(dataset_queries) == limit, 'has_more': len(dataset_queries) == limit,
"limit": limit, 'limit': limit,
"total": total, 'total': total,
"page": page, 'page': page
} }
return response, 200 return response, 200
class DatasetIndexingEstimateApi(Resource): class DatasetIndexingEstimateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json") parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
parser.add_argument( parser.add_argument('indexing_technique', type=str, required=True,
"indexing_technique", choices=Dataset.INDEXING_TECHNIQUE_LIST,
type=str, nullable=True, location='json')
required=True, parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
choices=Dataset.INDEXING_TECHNIQUE_LIST, parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
nullable=True, parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location="json", location='json')
)
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
# validate args # validate args
DocumentService.estimate_args_validate(args) DocumentService.estimate_args_validate(args)
extract_settings = [] extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file": if args['info_list']['data_source_type'] == 'upload_file':
file_ids = args["info_list"]["file_info_list"]["file_ids"] file_ids = args['info_list']['file_info_list']['file_ids']
file_details = ( file_details = db.session.query(UploadFile).filter(
db.session.query(UploadFile) UploadFile.tenant_id == current_user.current_tenant_id,
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) UploadFile.id.in_(file_ids)
.all() ).all()
)
if file_details is None: if file_details is None:
raise NotFound("File not found.") raise NotFound("File not found.")
@ -349,58 +344,55 @@ class DatasetIndexingEstimateApi(Resource):
if file_details: if file_details:
for file_detail in file_details: for file_detail in file_details:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"] datasource_type="upload_file",
upload_file=file_detail,
document_model=args['doc_form']
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import": elif args['info_list']['data_source_type'] == 'notion_import':
notion_info_list = args["info_list"]["notion_info_list"] notion_info_list = args['info_list']['notion_info_list']
for notion_info in notion_info_list: for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"] workspace_id = notion_info['workspace_id']
for page in notion_info["pages"]: for page in notion_info['pages']:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="notion_import", datasource_type="notion_import",
notion_info={ notion_info={
"notion_workspace_id": workspace_id, "notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"], "notion_obj_id": page['page_id'],
"notion_page_type": page["type"], "notion_page_type": page['type'],
"tenant_id": current_user.current_tenant_id, "tenant_id": current_user.current_tenant_id
}, },
document_model=args["doc_form"], document_model=args['doc_form']
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "website_crawl": elif args['info_list']['data_source_type'] == 'website_crawl':
website_info_list = args["info_list"]["website_info_list"] website_info_list = args['info_list']['website_info_list']
for url in website_info_list["urls"]: for url in website_info_list['urls']:
extract_setting = ExtractSetting( extract_setting = ExtractSetting(
datasource_type="website_crawl", datasource_type="website_crawl",
website_info={ website_info={
"provider": website_info_list["provider"], "provider": website_info_list['provider'],
"job_id": website_info_list["job_id"], "job_id": website_info_list['job_id'],
"url": url, "url": url,
"tenant_id": current_user.current_tenant_id, "tenant_id": current_user.current_tenant_id,
"mode": "crawl", "mode": 'crawl',
"only_main_content": website_info_list["only_main_content"], "only_main_content": website_info_list['only_main_content']
}, },
document_model=args["doc_form"], document_model=args['doc_form']
) )
extract_settings.append(extract_setting) extract_settings.append(extract_setting)
else: else:
raise ValueError("Data source type not support") raise ValueError('Data source type not support')
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
response = indexing_runner.indexing_estimate( response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
current_user.current_tenant_id, args['process_rule'], args['doc_form'],
extract_settings, args['doc_language'], args['dataset_id'],
args["process_rule"], args['indexing_technique'])
args["doc_form"],
args["doc_language"],
args["dataset_id"],
args["indexing_technique"],
)
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." "No Embedding Model available. Please configure a valid provider "
) "in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
except Exception as e: except Exception as e:
@ -410,6 +402,7 @@ class DatasetIndexingEstimateApi(Resource):
class DatasetRelatedAppListApi(Resource): class DatasetRelatedAppListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -433,52 +426,52 @@ class DatasetRelatedAppListApi(Resource):
if app_model: if app_model:
related_apps.append(app_model) related_apps.append(app_model)
return {"data": related_apps, "total": len(related_apps)}, 200 return {
'data': related_apps,
'total': len(related_apps)
}, 200
class DatasetIndexingStatusApi(Resource): class DatasetIndexingStatusApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id): def get(self, dataset_id):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
documents = ( documents = db.session.query(Document).filter(
db.session.query(Document) Document.dataset_id == dataset_id,
.filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) Document.tenant_id == current_user.current_tenant_id
.all() ).all()
)
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = DocumentSegment.query.filter( completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id),
DocumentSegment.document_id == str(document.id), DocumentSegment.status != 're_segment').count()
DocumentSegment.status != "re_segment", total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
).count() DocumentSegment.status != 're_segment').count()
total_segments = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
).count()
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
documents_status.append(marshal(document, document_status_fields)) documents_status.append(marshal(document, document_status_fields))
data = {"data": documents_status} data = {
'data': documents_status
}
return data return data
class DatasetApiKeyApi(Resource): class DatasetApiKeyApi(Resource):
max_keys = 10 max_keys = 10
token_prefix = "dataset-" token_prefix = 'dataset-'
resource_type = "dataset" resource_type = 'dataset'
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(api_key_list) @marshal_with(api_key_list)
def get(self): def get(self):
keys = ( keys = db.session.query(ApiToken). \
db.session.query(ApiToken) filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) all()
.all()
)
return {"items": keys} return {"items": keys}
@setup_required @setup_required
@ -490,17 +483,15 @@ class DatasetApiKeyApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
current_key_count = ( current_key_count = db.session.query(ApiToken). \
db.session.query(ApiToken) filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) count()
.count()
)
if current_key_count >= self.max_keys: if current_key_count >= self.max_keys:
flask_restful.abort( flask_restful.abort(
400, 400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.", message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded", code='max_keys_exceeded'
) )
key = ApiToken.generate_api_key(self.token_prefix, 24) key = ApiToken.generate_api_key(self.token_prefix, 24)
@ -514,7 +505,7 @@ class DatasetApiKeyApi(Resource):
class DatasetApiDeleteApi(Resource): class DatasetApiDeleteApi(Resource):
resource_type = "dataset" resource_type = 'dataset'
@setup_required @setup_required
@login_required @login_required
@ -526,23 +517,18 @@ class DatasetApiDeleteApi(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
key = ( key = db.session.query(ApiToken). \
db.session.query(ApiToken) filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
.filter( ApiToken.id == api_key_id). \
ApiToken.tenant_id == current_user.current_tenant_id, first()
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
.first()
)
if key is None: if key is None:
flask_restful.abort(404, message="API key not found") flask_restful.abort(404, message='API key not found')
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
return {"result": "success"}, 204 return {'result': 'success'}, 204
class DatasetApiBaseUrlApi(Resource): class DatasetApiBaseUrlApi(Resource):
@ -550,7 +536,10 @@ class DatasetApiBaseUrlApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"} return {
'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL
else request.host_url.rstrip('/')) + '/v1'
}
class DatasetRetrievalSettingApi(Resource): class DatasetRetrievalSettingApi(Resource):
@ -560,26 +549,15 @@ class DatasetRetrievalSettingApi(Resource):
def get(self): def get(self):
vector_type = dify_config.VECTOR_STORE vector_type = dify_config.VECTOR_STORE
match vector_type: match vector_type:
case ( case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
VectorType.MILVUS
| VectorType.RELYT
| VectorType.PGVECTOR
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
):
return { return {
"retrieval_method": [ 'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value,
@ -595,27 +573,15 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, vector_type): def get(self, vector_type):
match vector_type: match vector_type:
case ( case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
VectorType.MILVUS
| VectorType.RELYT
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
VectorType.QDRANT
| VectorType.WEAVIATE
| VectorType.OPENSEARCH
| VectorType.ANALYTICDB
| VectorType.MYSCALE
| VectorType.ORACLE
| VectorType.ELASTICSEARCH
| VectorType.PGVECTOR
):
return { return {
"retrieval_method": [ 'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value
]
}
case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH| VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.SEMANTIC_SEARCH.value,
RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value,
RetrievalMethod.HYBRID_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value,
@ -625,6 +591,7 @@ class DatasetRetrievalSettingMockApi(Resource):
raise ValueError(f"Unsupported vector db type {vector_type}.") raise ValueError(f"Unsupported vector db type {vector_type}.")
class DatasetErrorDocs(Resource): class DatasetErrorDocs(Resource):
@setup_required @setup_required
@login_required @login_required
@ -636,7 +603,10 @@ class DatasetErrorDocs(Resource):
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str) results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200 return {
'data': [marshal(item, document_status_fields) for item in results],
'total': len(results)
}, 200
class DatasetPermissionUserListApi(Resource): class DatasetPermissionUserListApi(Resource):
@ -656,21 +626,21 @@ class DatasetPermissionUserListApi(Resource):
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return { return {
"data": partial_members_list, 'data': partial_members_list,
}, 200 }, 200
api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check") api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check')
api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries") api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs") api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate") api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps") api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status") api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
api.add_resource(DatasetApiKeyApi, "/datasets/api-keys") api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>") api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>") api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users") api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users')

File diff suppressed because it is too large

View File

@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource):
document_id = str(document_id) document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound('Dataset not found.')
try: try:
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
@ -50,33 +50,37 @@ class DatasetDocumentSegmentListApi(Resource):
document = DocumentService.get_document(dataset_id, document_id) document = DocumentService.get_document(dataset_id, document_id)
if not document: if not document:
raise NotFound("Document not found.") raise NotFound('Document not found.')
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("last_id", type=str, default=None, location="args") parser.add_argument('last_id', type=str, default=None, location='args')
parser.add_argument("limit", type=int, default=20, location="args") parser.add_argument('limit', type=int, default=20, location='args')
parser.add_argument("status", type=str, action="append", default=[], location="args") parser.add_argument('status', type=str,
parser.add_argument("hit_count_gte", type=int, default=None, location="args") action='append', default=[], location='args')
parser.add_argument("enabled", type=str, default="all", location="args") parser.add_argument('hit_count_gte', type=int,
parser.add_argument("keyword", type=str, default=None, location="args") default=None, location='args')
parser.add_argument('enabled', type=str, default='all', location='args')
parser.add_argument('keyword', type=str, default=None, location='args')
args = parser.parse_args() args = parser.parse_args()
last_id = args["last_id"] last_id = args['last_id']
limit = min(args["limit"], 100) limit = min(args['limit'], 100)
status_list = args["status"] status_list = args['status']
hit_count_gte = args["hit_count_gte"] hit_count_gte = args['hit_count_gte']
keyword = args["keyword"] keyword = args['keyword']
query = DocumentSegment.query.filter( query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
) )
if last_id is not None: if last_id is not None:
last_segment = db.session.get(DocumentSegment, str(last_id)) last_segment = db.session.get(DocumentSegment, str(last_id))
if last_segment: if last_segment:
query = query.filter(DocumentSegment.position > last_segment.position) query = query.filter(
DocumentSegment.position > last_segment.position)
else: else:
return {"data": [], "has_more": False, "limit": limit}, 200 return {'data': [], 'has_more': False, 'limit': limit}, 200
if status_list: if status_list:
query = query.filter(DocumentSegment.status.in_(status_list)) query = query.filter(DocumentSegment.status.in_(status_list))
@ -85,12 +89,12 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.filter(DocumentSegment.hit_count >= hit_count_gte) query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
if keyword: if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
if args["enabled"].lower() != "all": if args['enabled'].lower() != 'all':
if args["enabled"].lower() == "true": if args['enabled'].lower() == 'true':
query = query.filter(DocumentSegment.enabled == True) query = query.filter(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false": elif args['enabled'].lower() == 'false':
query = query.filter(DocumentSegment.enabled == False) query = query.filter(DocumentSegment.enabled == False)
total = query.count() total = query.count()
@ -102,11 +106,11 @@ class DatasetDocumentSegmentListApi(Resource):
segments = segments[:-1] segments = segments[:-1]
return { return {
"data": marshal(segments, segment_fields), 'data': marshal(segments, segment_fields),
"doc_form": document.doc_form, 'doc_form': document.doc_form,
"has_more": has_more, 'has_more': has_more,
"limit": limit, 'limit': limit,
"total": total, 'total': total
}, 200 }, 200
@ -114,12 +118,12 @@ class DatasetDocumentSegmentApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check('vector_space')
def patch(self, dataset_id, segment_id, action): def patch(self, dataset_id, segment_id, action):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound('Dataset not found.')
# check user's model setting # check user's model setting
DatasetService.check_dataset_model_setting(dataset) DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
@ -130,7 +134,7 @@ class DatasetDocumentSegmentApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == 'high_quality':
# check embedding model setting # check embedding model setting
try: try:
model_manager = ModelManager() model_manager = ModelManager()
@ -138,32 +142,32 @@ class DatasetDocumentSegmentApi(Resource):
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider." "in the Settings -> Model Provider.")
)
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
segment = DocumentSegment.query.filter( segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first() ).first()
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound('Segment not found.')
if segment.status != "completed": if segment.status != 'completed':
raise NotFound("Segment is not completed, enable or disable function is not allowed") raise NotFound('Segment is not completed, enable or disable function is not allowed')
document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
cache_result = redis_client.get(document_indexing_cache_key) cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None: if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later") raise InvalidActionError("Document is being indexed, please try again later")
indexing_cache_key = "segment_{}_indexing".format(segment.id) indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None: if cache_result is not None:
raise InvalidActionError("Segment is being indexed, please try again later") raise InvalidActionError("Segment is being indexed, please try again later")
@ -182,7 +186,7 @@ class DatasetDocumentSegmentApi(Resource):
enable_segment_to_index_task.delay(segment.id) enable_segment_to_index_task.delay(segment.id)
return {"result": "success"}, 200 return {'result': 'success'}, 200
elif action == "disable": elif action == "disable":
if not segment.enabled: if not segment.enabled:
raise InvalidActionError("Segment is already disabled.") raise InvalidActionError("Segment is already disabled.")
@ -197,7 +201,7 @@ class DatasetDocumentSegmentApi(Resource):
disable_segment_from_index_task.delay(segment.id) disable_segment_from_index_task.delay(segment.id)
return {"result": "success"}, 200 return {'result': 'success'}, 200
else: else:
raise InvalidActionError() raise InvalidActionError()
@ -206,36 +210,35 @@ class DatasetDocumentSegmentAddApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check('vector_space')
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check('add_segment')
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound('Dataset not found.')
# check document # check document
document_id = str(document_id) document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id) document = DocumentService.get_document(dataset_id, document_id)
if not document: if not document:
raise NotFound("Document not found.") raise NotFound('Document not found.')
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
# check embedding model setting # check embedding model setting
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == 'high_quality':
try: try:
model_manager = ModelManager() model_manager = ModelManager()
model_manager.get_model_instance( model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider." "in the Settings -> Model Provider.")
)
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
try: try:
@ -244,34 +247,37 @@ class DatasetDocumentSegmentAddApi(Resource):
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument("answer", type=str, required=False, nullable=True, location="json") parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.create_segment(args, document, dataset) segment = SegmentService.create_segment(args, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
class DatasetDocumentSegmentUpdateApi(Resource): class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check('vector_space')
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound('Dataset not found.')
# check user's model setting # check user's model setting
DatasetService.check_dataset_model_setting(dataset) DatasetService.check_dataset_model_setting(dataset)
# check document # check document
document_id = str(document_id) document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id) document = DocumentService.get_document(dataset_id, document_id)
if not document: if not document:
raise NotFound("Document not found.") raise NotFound('Document not found.')
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == 'high_quality':
# check embedding model setting # check embedding model setting
try: try:
model_manager = ModelManager() model_manager = ModelManager()
@ -279,22 +285,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider, provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider." "in the Settings -> Model Provider.")
)
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first() ).first()
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@ -304,13 +310,16 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise Forbidden(str(e)) raise Forbidden(str(e))
# validate args # validate args
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument('content', type=str, required=True, nullable=False, location='json')
parser.add_argument("answer", type=str, required=False, nullable=True, location="json") parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset) segment = SegmentService.update_segment(args, segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
@setup_required @setup_required
@login_required @login_required
@ -320,21 +329,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound('Dataset not found.')
# check user's model setting # check user's model setting
DatasetService.check_dataset_model_setting(dataset) DatasetService.check_dataset_model_setting(dataset)
# check document # check document
document_id = str(document_id) document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id) document = DocumentService.get_document(dataset_id, document_id)
if not document: if not document:
raise NotFound("Document not found.") raise NotFound('Document not found.')
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first() ).first()
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound('Segment not found.')
# The role of the current user in the ta table must be admin or owner # The role of the current user in the ta table must be admin or owner
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@ -343,36 +353,36 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
SegmentService.delete_segment(segment, document, dataset) SegmentService.delete_segment(segment, document, dataset)
return {"result": "success"}, 200 return {'result': 'success'}, 200
class DatasetDocumentSegmentBatchImportApi(Resource): class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check('vector_space')
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check('add_segment')
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound('Dataset not found.')
# check document # check document
document_id = str(document_id) document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id) document = DocumentService.get_document(dataset_id, document_id)
if not document: if not document:
raise NotFound("Document not found.") raise NotFound('Document not found.')
# get file from request # get file from request
file = request.files["file"] file = request.files['file']
# check file # check file
if "file" not in request.files: if 'file' not in request.files:
raise NoFileUploadedError() raise NoFileUploadedError()
if len(request.files) > 1: if len(request.files) > 1:
raise TooManyFilesError() raise TooManyFilesError()
# check file type # check file type
if not file.filename.endswith(".csv"): if not file.filename.endswith('.csv'):
raise ValueError("Invalid file type. Only CSV files are allowed") raise ValueError("Invalid file type. Only CSV files are allowed")
try: try:
@ -380,47 +390,51 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
df = pd.read_csv(file) df = pd.read_csv(file)
result = [] result = []
for index, row in df.iterrows(): for index, row in df.iterrows():
if document.doc_form == "qa_model": if document.doc_form == 'qa_model':
data = {"content": row[0], "answer": row[1]} data = {'content': row[0], 'answer': row[1]}
else: else:
data = {"content": row[0]} data = {'content': row[0]}
result.append(data) result.append(data)
if len(result) == 0: if len(result) == 0:
raise ValueError("The CSV file is empty.") raise ValueError("The CSV file is empty.")
# async job # async job
job_id = str(uuid.uuid4()) job_id = str(uuid.uuid4())
indexing_cache_key = "segment_batch_import_{}".format(str(job_id)) indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
# send batch add segments task # send batch add segments task
redis_client.setnx(indexing_cache_key, "waiting") redis_client.setnx(indexing_cache_key, 'waiting')
batch_create_segment_to_index_task.delay( batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id current_user.current_tenant_id, current_user.id)
)
except Exception as e: except Exception as e:
return {"error": str(e)}, 500 return {'error': str(e)}, 500
return {"job_id": job_id, "job_status": "waiting"}, 200 return {
'job_id': job_id,
'job_status': 'waiting'
}, 200
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, job_id): def get(self, job_id):
job_id = str(job_id) job_id = str(job_id)
indexing_cache_key = "segment_batch_import_{}".format(job_id) indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job is not exist.")
return {"job_id": job_id, "job_status": cache_result.decode()}, 200 return {
'job_id': job_id,
'job_status': cache_result.decode()
}, 200
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") api.add_resource(DatasetDocumentSegmentListApi,
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>") '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment") api.add_resource(DatasetDocumentSegmentApi,
api.add_resource( '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
DatasetDocumentSegmentUpdateApi, api.add_resource(DatasetDocumentSegmentAddApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>", '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
) api.add_resource(DatasetDocumentSegmentUpdateApi,
api.add_resource( '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
DatasetDocumentSegmentBatchImportApi, api.add_resource(DatasetDocumentSegmentBatchImportApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import", '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
"/datasets/batch_import_status/<uuid:job_id>", '/datasets/batch_import_status/<uuid:job_id>')
)

View File

@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException): class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded" error_code = 'no_file_uploaded'
description = "Please upload your file." description = "Please upload your file."
code = 400 code = 400
class TooManyFilesError(BaseHTTPException): class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files" error_code = 'too_many_files'
description = "Only one file is allowed." description = "Only one file is allowed."
code = 400 code = 400
class FileTooLargeError(BaseHTTPException): class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large" error_code = 'file_too_large'
description = "File size exceeded. {message}" description = "File size exceeded. {message}"
code = 413 code = 413
class UnsupportedFileTypeError(BaseHTTPException): class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type" error_code = 'unsupported_file_type'
description = "File type not allowed." description = "File type not allowed."
code = 415 code = 415
class HighQualityDatasetOnlyError(BaseHTTPException): class HighQualityDatasetOnlyError(BaseHTTPException):
error_code = "high_quality_dataset_only" error_code = 'high_quality_dataset_only'
description = "Current operation only supports 'high-quality' datasets." description = "Current operation only supports 'high-quality' datasets."
code = 400 code = 400
class DatasetNotInitializedError(BaseHTTPException): class DatasetNotInitializedError(BaseHTTPException):
error_code = "dataset_not_initialized" error_code = 'dataset_not_initialized'
description = "The dataset is still being initialized or indexing. Please wait a moment." description = "The dataset is still being initialized or indexing. Please wait a moment."
code = 400 code = 400
class ArchivedDocumentImmutableError(BaseHTTPException): class ArchivedDocumentImmutableError(BaseHTTPException):
error_code = "archived_document_immutable" error_code = 'archived_document_immutable'
description = "The archived document is not editable." description = "The archived document is not editable."
code = 403 code = 403
class DatasetNameDuplicateError(BaseHTTPException): class DatasetNameDuplicateError(BaseHTTPException):
error_code = "dataset_name_duplicate" error_code = 'dataset_name_duplicate'
description = "The dataset name already exists. Please modify your dataset name." description = "The dataset name already exists. Please modify your dataset name."
code = 409 code = 409
class InvalidActionError(BaseHTTPException): class InvalidActionError(BaseHTTPException):
error_code = "invalid_action" error_code = 'invalid_action'
description = "Invalid action." description = "Invalid action."
code = 400 code = 400
class DocumentAlreadyFinishedError(BaseHTTPException): class DocumentAlreadyFinishedError(BaseHTTPException):
error_code = "document_already_finished" error_code = 'document_already_finished'
description = "The document has been processed. Please refresh the page or go to the document details." description = "The document has been processed. Please refresh the page or go to the document details."
code = 400 code = 400
class DocumentIndexingError(BaseHTTPException): class DocumentIndexingError(BaseHTTPException):
error_code = "document_indexing" error_code = 'document_indexing'
description = "The document is being processed and cannot be edited." description = "The document is being processed and cannot be edited."
code = 400 code = 400
class InvalidMetadataError(BaseHTTPException): class InvalidMetadataError(BaseHTTPException):
error_code = "invalid_metadata" error_code = 'invalid_metadata'
description = "The metadata content is incorrect. Please check and verify." description = "The metadata content is incorrect. Please check and verify."
code = 400 code = 400
class WebsiteCrawlError(BaseHTTPException): class WebsiteCrawlError(BaseHTTPException):
error_code = "crawl_failed" error_code = 'crawl_failed'
description = "{message}" description = "{message}"
code = 500 code = 500
class DatasetInUseError(BaseHTTPException): class DatasetInUseError(BaseHTTPException):
error_code = "dataset_in_use" error_code = 'dataset_in_use'
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
code = 409 code = 409
class IndexingEstimateError(BaseHTTPException): class IndexingEstimateError(BaseHTTPException):
error_code = "indexing_estimate_error" error_code = 'indexing_estimate_error'
description = "Knowledge indexing estimate failed: {message}" description = "Knowledge indexing estimate failed: {message}"
code = 500 code = 500

View File

@ -21,6 +21,7 @@ PREVIEW_WORDS_LIMIT = 3000
class FileApi(Resource): class FileApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -30,22 +31,23 @@ class FileApi(Resource):
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
return { return {
"file_size_limit": file_size_limit, 'file_size_limit': file_size_limit,
"batch_count_limit": batch_count_limit, 'batch_count_limit': batch_count_limit,
"image_file_size_limit": image_file_size_limit, 'image_file_size_limit': image_file_size_limit
}, 200 }, 200
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(file_fields) @marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents") @cloud_edition_billing_resource_check(resource='documents')
def post(self): def post(self):
# get file from request # get file from request
file = request.files["file"] file = request.files['file']
# check file # check file
if "file" not in request.files: if 'file' not in request.files:
raise NoFileUploadedError() raise NoFileUploadedError()
if len(request.files) > 1: if len(request.files) > 1:
@ -67,7 +69,7 @@ class FilePreviewApi(Resource):
def get(self, file_id): def get(self, file_id):
file_id = str(file_id) file_id = str(file_id)
text = FileService.get_file_preview(file_id) text = FileService.get_file_preview(file_id)
return {"content": text} return {'content': text}
class FileSupportTypeApi(Resource): class FileSupportTypeApi(Resource):
@ -76,10 +78,10 @@ class FileSupportTypeApi(Resource):
@account_initialization_required @account_initialization_required
def get(self): def get(self):
etl_type = dify_config.ETL_TYPE etl_type = dify_config.ETL_TYPE
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
return {"allowed_extensions": allowed_extensions} return {'allowed_extensions': allowed_extensions}
api.add_resource(FileApi, "/files/upload") api.add_resource(FileApi, '/files/upload')
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview") api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview')
api.add_resource(FileSupportTypeApi, "/files/support-type") api.add_resource(FileSupportTypeApi, '/files/support-type')

View File

@ -29,6 +29,7 @@ from services.hit_testing_service import HitTestingService
class HitTestingApi(Resource): class HitTestingApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -45,8 +46,8 @@ class HitTestingApi(Resource):
raise Forbidden(str(e)) raise Forbidden(str(e))
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json") parser.add_argument('query', type=str, location='json')
parser.add_argument("retrieval_model", type=dict, required=False, location="json") parser.add_argument('retrieval_model', type=dict, required=False, location='json')
args = parser.parse_args() args = parser.parse_args()
HitTestingService.hit_testing_args_check(args) HitTestingService.hit_testing_args_check(args)
@ -54,13 +55,13 @@ class HitTestingApi(Resource):
try: try:
response = HitTestingService.retrieve( response = HitTestingService.retrieve(
dataset=dataset, dataset=dataset,
query=args["query"], query=args['query'],
account=current_user, account=current_user,
retrieval_model=args["retrieval_model"], retrieval_model=args['retrieval_model'],
limit=10, limit=10
) )
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError: except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError() raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
@ -72,8 +73,7 @@ class HitTestingApi(Resource):
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider " "No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider." "in the Settings -> Model Provider.")
)
except InvokeError as e: except InvokeError as e:
raise CompletionRequestError(e.description) raise CompletionRequestError(e.description)
except ValueError as e: except ValueError as e:
@ -83,4 +83,4 @@ class HitTestingApi(Resource):
raise InternalServerError(str(e)) raise InternalServerError(str(e))
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')

View File

@ -9,14 +9,16 @@ from services.website_service import WebsiteService
class WebsiteCrawlApi(Resource): class WebsiteCrawlApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json") parser.add_argument('provider', type=str, choices=['firecrawl'],
parser.add_argument("url", type=str, required=True, nullable=True, location="json") required=True, nullable=True, location='json')
parser.add_argument("options", type=dict, required=True, nullable=True, location="json") parser.add_argument('url', type=str, required=True, nullable=True, location='json')
parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
args = parser.parse_args() args = parser.parse_args()
WebsiteService.document_create_args_validate(args) WebsiteService.document_create_args_validate(args)
# crawl url # crawl url
@ -33,15 +35,15 @@ class WebsiteCrawlStatusApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, job_id: str): def get(self, job_id: str):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args") parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
args = parser.parse_args() args = parser.parse_args()
# get crawl status # get crawl status
try: try:
result = WebsiteService.get_crawl_status(job_id, args["provider"]) result = WebsiteService.get_crawl_status(job_id, args['provider'])
except Exception as e: except Exception as e:
raise WebsiteCrawlError(str(e)) raise WebsiteCrawlError(str(e))
return result, 200 return result, 200
api.add_resource(WebsiteCrawlApi, "/website/crawl") api.add_resource(WebsiteCrawlApi, '/website/crawl')
api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/<string:job_id>") api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')

View File

@ -2,41 +2,35 @@ from libs.exception import BaseHTTPException
class AlreadySetupError(BaseHTTPException): class AlreadySetupError(BaseHTTPException):
error_code = "already_setup" error_code = 'already_setup'
description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage." description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage."
code = 403 code = 403
class NotSetupError(BaseHTTPException): class NotSetupError(BaseHTTPException):
error_code = "not_setup" error_code = 'not_setup'
description = ( description = "Dify has not been initialized and installed yet. " \
"Dify has not been initialized and installed yet. " "Please proceed with the initialization and installation process first."
"Please proceed with the initialization and installation process first."
)
code = 401 code = 401
class NotInitValidateError(BaseHTTPException): class NotInitValidateError(BaseHTTPException):
error_code = "not_init_validated" error_code = 'not_init_validated'
description = ( description = "Init validation has not been completed yet. " \
"Init validation has not been completed yet. " "Please proceed with the init validation process first." "Please proceed with the init validation process first."
)
code = 401 code = 401
class InitValidateFailedError(BaseHTTPException): class InitValidateFailedError(BaseHTTPException):
error_code = "init_validate_failed" error_code = 'init_validate_failed'
description = "Init validation failed. Please check the password and try again." description = "Init validation failed. Please check the password and try again."
code = 401 code = 401
class AccountNotLinkTenantError(BaseHTTPException): class AccountNotLinkTenantError(BaseHTTPException):
error_code = "account_not_link_tenant" error_code = 'account_not_link_tenant'
description = "Account not link tenant." description = "Account not link tenant."
code = 403 code = 403
class AlreadyActivateError(BaseHTTPException): class AlreadyActivateError(BaseHTTPException):
error_code = "already_activate" error_code = 'already_activate'
description = "Auth Token is invalid or account already activated, please check again." description = "Auth Token is invalid or account already activated, please check again."
code = 403 code = 403
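Hedged illustration (not part of this diff): every exception class above follows the same shape, declaring error_code, description, and code as class attributes on a shared base. BaseHTTPException lives in libs.exception and is not shown here, so the sketch below uses a stand-in base class:

# Stand-in for libs.exception.BaseHTTPException, assumed only for this example.
class StubBaseHTTPException(Exception):
    error_code = "unknown"
    description = ""
    code = 500

class ExampleAlreadySetupError(StubBaseHTTPException):
    error_code = "already_setup"
    description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage."
    code = 403

err = ExampleAlreadySetupError()
print(err.error_code, err.code)  # already_setup 403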


@ -33,10 +33,14 @@ class ChatAudioApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
file = request.files["file"] file = request.files['file']
try: try:
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) response = AudioService.transcript_asr(
app_model=app_model,
file=file,
end_user=None
)
return response return response
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
@ -72,27 +76,30 @@ class ChatTextApi(InstalledAppResource):
app_model = installed_app.app app_model = installed_app.app
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json") parser.add_argument('message_id', type=str, required=False, location='json')
parser.add_argument("voice", type=str, location="json") parser.add_argument('voice', type=str, location='json')
parser.add_argument("text", type=str, location="json") parser.add_argument('text', type=str, location='json')
parser.add_argument("streaming", type=bool, location="json") parser.add_argument('streaming', type=bool, location='json')
args = parser.parse_args() args = parser.parse_args()
message_id = args.get("message_id", None) message_id = args.get('message_id', None)
text = args.get("text", None) text = args.get('text', None)
if ( if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] and app_model.workflow
and app_model.workflow and app_model.workflow.features_dict):
and app_model.workflow.features_dict text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
): voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
except Exception: except Exception:
voice = None voice = None
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text) response = AudioService.transcript_tts(
app_model=app_model,
message_id=message_id,
voice=voice,
text=text
)
return response return response
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
@ -120,7 +127,7 @@ class ChatTextApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio") api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text") api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id', # api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
# endpoint='installed_app_text_with_message_id') # endpoint='installed_app_text_with_message_id')
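One behavioural detail worth calling out (this note is not part of the diff): the rewritten voice lookup uses "or", so any falsy request value - None or an empty string - falls through to the voice configured in the text_to_speech feature. A tiny standalone check of that behaviour, with a made-up config:

text_to_speech = {"voice": "alloy"}  # assumed example config

def pick_voice(requested):
    # Equivalent to: requested if requested else text_to_speech.get("voice")
    return requested or text_to_speech.get("voice")

print(pick_voice(None))     # alloy
print(pick_voice(""))       # alloy
print(pick_voice("verse"))  # verse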


@ -30,28 +30,33 @@ from services.app_generate_service import AppGenerateService
# define completion api for user # define completion api for user
class CompletionApi(InstalledAppResource): class CompletionApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument("query", type=str, location="json", default="") parser.add_argument('query', type=str, location='json', default='')
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = args['response_mode'] == 'streaming'
args["auto_generate_name"] = False args['auto_generate_name'] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -80,12 +85,12 @@ class CompletionApi(InstalledAppResource):
class CompletionStopApi(InstalledAppResource): class CompletionStopApi(InstalledAppResource):
def post(self, installed_app, task_id): def post(self, installed_app, task_id):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}, 200 return {'result': 'success'}, 200
class ChatApi(InstalledAppResource): class ChatApi(InstalledAppResource):
@ -96,21 +101,25 @@ class ChatApi(InstalledAppResource):
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json") parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument("query", type=str, required=True, location="json") parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument("conversation_id", type=uuid_value, location="json") parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args() args = parser.parse_args()
args["auto_generate_name"] = False args['auto_generate_name'] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.EXPLORE,
streaming=True
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -145,22 +154,10 @@ class ChatStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}, 200 return {'result': 'success'}, 200
api.add_resource( api.add_resource(CompletionApi, '/installed-apps/<uuid:installed_app_id>/completion-messages', endpoint='installed_app_completion')
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion" api.add_resource(CompletionStopApi, '/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop', endpoint='installed_app_stop_completion')
) api.add_resource(ChatApi, '/installed-apps/<uuid:installed_app_id>/chat-messages', endpoint='installed_app_chat_completion')
api.add_resource( api.add_resource(ChatStopApi, '/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop', endpoint='installed_app_stop_chat_completion')
CompletionStopApi,
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
api.add_resource(
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
)
api.add_resource(
ChatStopApi,
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)


@ -16,6 +16,7 @@ from services.web_conversation_service import WebConversationService
class ConversationListApi(InstalledAppResource): class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
@ -24,21 +25,21 @@ class ConversationListApi(InstalledAppResource):
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args')
args = parser.parse_args() args = parser.parse_args()
pinned = None pinned = None
if "pinned" in args and args["pinned"] is not None: if 'pinned' in args and args['pinned'] is not None:
pinned = True if args["pinned"] == "true" else False pinned = True if args['pinned'] == 'true' else False
try: try:
return WebConversationService.pagination_by_last_id( return WebConversationService.pagination_by_last_id(
app_model=app_model, app_model=app_model,
user=current_user, user=current_user,
last_id=args["last_id"], last_id=args['last_id'],
limit=args["limit"], limit=args['limit'],
invoke_from=InvokeFrom.EXPLORE, invoke_from=InvokeFrom.EXPLORE,
pinned=pinned, pinned=pinned,
) )
@ -64,6 +65,7 @@ class ConversationApi(InstalledAppResource):
class ConversationRenameApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)
def post(self, installed_app, c_id): def post(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
@ -74,19 +76,24 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id) conversation_id = str(c_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, location="json") parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args() args = parser.parse_args()
try: try:
return ConversationService.rename( return ConversationService.rename(
app_model, conversation_id, current_user, args["name"], args["auto_generate"] app_model,
conversation_id,
current_user,
args['name'],
args['auto_generate']
) )
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
class ConversationPinApi(InstalledAppResource): class ConversationPinApi(InstalledAppResource):
def patch(self, installed_app, c_id): def patch(self, installed_app, c_id):
app_model = installed_app.app app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode) app_mode = AppMode.value_of(app_model.mode)
@ -116,26 +123,8 @@ class ConversationUnPinApi(InstalledAppResource):
return {"result": "success"} return {"result": "success"}
api.add_resource( api.add_resource(ConversationRenameApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name', endpoint='installed_app_conversation_rename')
ConversationRenameApi, api.add_resource(ConversationListApi, '/installed-apps/<uuid:installed_app_id>/conversations', endpoint='installed_app_conversations')
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name", api.add_resource(ConversationApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>', endpoint='installed_app_conversation')
endpoint="installed_app_conversation_rename", api.add_resource(ConversationPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin', endpoint='installed_app_conversation_pin')
) api.add_resource(ConversationUnPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin', endpoint='installed_app_conversation_unpin')
api.add_resource(
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
)
api.add_resource(
ConversationApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
api.add_resource(
ConversationPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
api.add_resource(
ConversationUnPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)


@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException
class NotCompletionAppError(BaseHTTPException): class NotCompletionAppError(BaseHTTPException):
error_code = "not_completion_app" error_code = 'not_completion_app'
description = "Not Completion App" description = "Not Completion App"
code = 400 code = 400
class NotChatAppError(BaseHTTPException): class NotChatAppError(BaseHTTPException):
error_code = "not_chat_app" error_code = 'not_chat_app'
description = "App mode is invalid." description = "App mode is invalid."
code = 400 code = 400
class NotWorkflowAppError(BaseHTTPException): class NotWorkflowAppError(BaseHTTPException):
error_code = "not_workflow_app" error_code = 'not_workflow_app'
description = "Only support workflow app." description = "Only support workflow app."
code = 400 code = 400
class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException):
error_code = "app_suggested_questions_after_answer_disabled" error_code = 'app_suggested_questions_after_answer_disabled'
description = "Function Suggested questions after answer disabled." description = "Function Suggested questions after answer disabled."
code = 403 code = 403


@ -21,72 +21,72 @@ class InstalledAppsListApi(Resource):
@marshal_with(installed_app_list_fields) @marshal_with(installed_app_list_fields)
def get(self): def get(self):
current_tenant_id = current_user.current_tenant_id current_tenant_id = current_user.current_tenant_id
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() installed_apps = db.session.query(InstalledApp).filter(
InstalledApp.tenant_id == current_tenant_id
).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_apps = [ installed_apps = [
{ {
"id": installed_app.id, 'id': installed_app.id,
"app": installed_app.app, 'app': installed_app.app,
"app_owner_tenant_id": installed_app.app_owner_tenant_id, 'app_owner_tenant_id': installed_app.app_owner_tenant_id,
"is_pinned": installed_app.is_pinned, 'is_pinned': installed_app.is_pinned,
"last_used_at": installed_app.last_used_at, 'last_used_at': installed_app.last_used_at,
"editable": current_user.role in ["owner", "admin"], 'editable': current_user.role in ["owner", "admin"],
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id
} }
for installed_app in installed_apps for installed_app in installed_apps
if installed_app.app is not None
] ]
installed_apps.sort( installed_apps.sort(key=lambda app: (-app['is_pinned'],
key=lambda app: ( app['last_used_at'] is None,
-app["is_pinned"], -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
app["last_used_at"] is None,
-app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0,
)
)
return {"installed_apps": installed_apps} return {'installed_apps': installed_apps}
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check('apps')
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')
args = parser.parse_args() args = parser.parse_args()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
if recommended_app is None: if recommended_app is None:
raise NotFound("App not found") raise NotFound('App not found')
current_tenant_id = current_user.current_tenant_id current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).filter(App.id == args["app_id"]).first() app = db.session.query(App).filter(
App.id == args['app_id']
).first()
if app is None: if app is None:
raise NotFound("App not found") raise NotFound('App not found')
if not app.is_public: if not app.is_public:
raise Forbidden("You can't install a non-public app") raise Forbidden('You can\'t install a non-public app')
installed_app = InstalledApp.query.filter( installed_app = InstalledApp.query.filter(and_(
and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id) InstalledApp.app_id == args['app_id'],
).first() InstalledApp.tenant_id == current_tenant_id
)).first()
if installed_app is None: if installed_app is None:
# todo: position # todo: position
recommended_app.install_count += 1 recommended_app.install_count += 1
new_installed_app = InstalledApp( new_installed_app = InstalledApp(
app_id=args["app_id"], app_id=args['app_id'],
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id, app_owner_tenant_id=app.tenant_id,
is_pinned=False, is_pinned=False,
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), last_used_at=datetime.now(timezone.utc).replace(tzinfo=None)
) )
db.session.add(new_installed_app) db.session.add(new_installed_app)
db.session.commit() db.session.commit()
return {"message": "App installed successfully"} return {'message': 'App installed successfully'}
class InstalledAppApi(InstalledAppResource): class InstalledAppApi(InstalledAppResource):
@ -94,31 +94,30 @@ class InstalledAppApi(InstalledAppResource):
update and delete an installed app update and delete an installed app
use InstalledAppResource to apply default decorators and get installed_app use InstalledAppResource to apply default decorators and get installed_app
""" """
def delete(self, installed_app): def delete(self, installed_app):
if installed_app.app_owner_tenant_id == current_user.current_tenant_id: if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant") raise BadRequest('You can\'t uninstall an app owned by the current tenant')
db.session.delete(installed_app) db.session.delete(installed_app)
db.session.commit() db.session.commit()
return {"result": "success", "message": "App uninstalled successfully"} return {'result': 'success', 'message': 'App uninstalled successfully'}
def patch(self, installed_app): def patch(self, installed_app):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("is_pinned", type=inputs.boolean) parser.add_argument('is_pinned', type=inputs.boolean)
args = parser.parse_args() args = parser.parse_args()
commit_args = False commit_args = False
if "is_pinned" in args: if 'is_pinned' in args:
installed_app.is_pinned = args["is_pinned"] installed_app.is_pinned = args['is_pinned']
commit_args = True commit_args = True
if commit_args: if commit_args:
db.session.commit() db.session.commit()
return {"result": "success", "message": "App info updated successfully"} return {'result': 'success', 'message': 'App info updated successfully'}
api.add_resource(InstalledAppsListApi, "/installed-apps") api.add_resource(InstalledAppsListApi, '/installed-apps')
api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>") api.add_resource(InstalledAppApi, '/installed-apps/<uuid:installed_app_id>')
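Hedged illustration (not part of this diff): the reformatted sort key in InstalledAppsListApi orders apps pinned-first, pushes never-used apps to the end of their group, and puts the most recently used first. The sample data below is made up purely to show the resulting order:

from datetime import datetime

apps = [
    {"id": "a", "is_pinned": False, "last_used_at": datetime(2024, 8, 1)},
    {"id": "b", "is_pinned": True, "last_used_at": None},
    {"id": "c", "is_pinned": False, "last_used_at": None},
    {"id": "d", "is_pinned": False, "last_used_at": datetime(2024, 8, 15)},
]
apps.sort(
    key=lambda app: (
        -app["is_pinned"],            # pinned apps (True -> -1) sort before unpinned (False -> 0)
        app["last_used_at"] is None,  # apps that were never used go last within their group
        -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0,  # newest first
    )
)
print([app["id"] for app in apps])  # ['b', 'd', 'a', 'c']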


@ -44,21 +44,19 @@ class MessageListApi(InstalledAppResource):
raise NotChatAppError() raise NotChatAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument("first_id", type=uuid_value, location="args") parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args() args = parser.parse_args()
try: try:
return MessageService.pagination_by_first_id( return MessageService.pagination_by_first_id(app_model, current_user,
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] args['conversation_id'], args['first_id'], args['limit'])
)
except services.errors.conversation.ConversationNotExistsError: except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
except services.errors.message.FirstMessageNotExistsError: except services.errors.message.FirstMessageNotExistsError:
raise NotFound("First Message Not Exists.") raise NotFound("First Message Not Exists.")
class MessageFeedbackApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource):
def post(self, installed_app, message_id): def post(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app
@ -66,32 +64,30 @@ class MessageFeedbackApi(InstalledAppResource):
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
args = parser.parse_args() args = parser.parse_args()
try: try:
MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) MessageService.create_feedback(app_model, message_id, current_user, args['rating'])
except services.errors.message.MessageNotExistsError: except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
return {"result": "success"} return {'result': 'success'}
class MessageMoreLikeThisApi(InstalledAppResource): class MessageMoreLikeThisApi(InstalledAppResource):
def get(self, installed_app, message_id): def get(self, installed_app, message_id):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
message_id = str(message_id) message_id = str(message_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
"response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args"
)
args = parser.parse_args() args = parser.parse_args()
streaming = args["response_mode"] == "streaming" streaming = args['response_mode'] == 'streaming'
try: try:
response = AppGenerateService.generate_more_like_this( response = AppGenerateService.generate_more_like_this(
@ -99,7 +95,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
user=current_user, user=current_user,
message_id=message_id, message_id=message_id,
invoke_from=InvokeFrom.EXPLORE, invoke_from=InvokeFrom.EXPLORE,
streaming=streaming, streaming=streaming
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
except MessageNotExistsError: except MessageNotExistsError:
@ -132,7 +128,10 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
try: try:
questions = MessageService.get_suggested_questions_after_answer( questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE app_model=app_model,
user=current_user,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE
) )
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message not found") raise NotFound("Message not found")
@ -152,22 +151,10 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
return {"data": questions} return {'data': questions}
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages") api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages')
api.add_resource( api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback')
MessageFeedbackApi, api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this')
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks", api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question')
endpoint="installed_app_message_feedback",
)
api.add_resource(
MessageMoreLikeThisApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
api.add_resource(
MessageSuggestedQuestionApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)


@ -1,3 +1,4 @@
from flask_restful import fields, marshal_with from flask_restful import fields, marshal_with
from configs import dify_config from configs import dify_config
@ -10,32 +11,33 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource): class AppParameterApi(InstalledAppResource):
"""Resource for app variables.""" """Resource for app variables."""
variable_fields = { variable_fields = {
"key": fields.String, 'key': fields.String,
"name": fields.String, 'name': fields.String,
"description": fields.String, 'description': fields.String,
"type": fields.String, 'type': fields.String,
"default": fields.String, 'default': fields.String,
"max_length": fields.Integer, 'max_length': fields.Integer,
"options": fields.List(fields.String), 'options': fields.List(fields.String)
} }
system_parameters_fields = {"image_file_size_limit": fields.String} system_parameters_fields = {
'image_file_size_limit': fields.String
}
parameters_fields = { parameters_fields = {
"opening_statement": fields.String, 'opening_statement': fields.String,
"suggested_questions": fields.Raw, 'suggested_questions': fields.Raw,
"suggested_questions_after_answer": fields.Raw, 'suggested_questions_after_answer': fields.Raw,
"speech_to_text": fields.Raw, 'speech_to_text': fields.Raw,
"text_to_speech": fields.Raw, 'text_to_speech': fields.Raw,
"retriever_resource": fields.Raw, 'retriever_resource': fields.Raw,
"annotation_reply": fields.Raw, 'annotation_reply': fields.Raw,
"more_like_this": fields.Raw, 'more_like_this': fields.Raw,
"user_input_form": fields.Raw, 'user_input_form': fields.Raw,
"sensitive_word_avoidance": fields.Raw, 'sensitive_word_avoidance': fields.Raw,
"file_upload": fields.Raw, 'file_upload': fields.Raw,
"system_parameters": fields.Nested(system_parameters_fields), 'system_parameters': fields.Nested(system_parameters_fields)
} }
@marshal_with(parameters_fields) @marshal_with(parameters_fields)
@ -54,35 +56,30 @@ class AppParameterApi(InstalledAppResource):
app_model_config = app_model.app_model_config app_model_config = app_model.app_model_config
features_dict = app_model_config.to_dict() features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get('user_input_form', [])
return { return {
"opening_statement": features_dict.get("opening_statement"), 'opening_statement': features_dict.get('opening_statement'),
"suggested_questions": features_dict.get("suggested_questions", []), 'suggested_questions': features_dict.get('suggested_questions', []),
"suggested_questions_after_answer": features_dict.get( 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer',
"suggested_questions_after_answer", {"enabled": False} {"enabled": False}),
), 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), 'more_like_this': features_dict.get('more_like_this', {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}), 'user_input_form': user_input_form,
"user_input_form": user_input_form, 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance',
"sensitive_word_avoidance": features_dict.get( {"enabled": False, "type": "", "configs": []}),
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} 'file_upload': features_dict.get('file_upload', {"image": {
), "enabled": False,
"file_upload": features_dict.get( "number_limits": 3,
"file_upload", "detail": "high",
{ "transfer_methods": ["remote_url", "local_file"]
"image": { }}),
"enabled": False, 'system_parameters': {
"number_limits": 3, 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
"detail": "high", }
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT},
} }
@ -93,7 +90,6 @@ class ExploreAppMetaApi(InstalledAppResource):
return AppService().get_app_meta(app_model) return AppService().get_app_meta(app_model)
api.add_resource( api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters',
AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters" endpoint='installed_app_parameters')
) api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta')
api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")
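A minimal sketch (not part of this diff) of how the flask_restful field maps used by AppParameterApi shape a response. marshal() works on plain dicts with no app context; every value below is assumed for illustration:

from flask_restful import fields, marshal

system_parameters_fields = {"image_file_size_limit": fields.String}
parameters_fields = {
    "opening_statement": fields.String,
    "suggested_questions": fields.Raw,
    "system_parameters": fields.Nested(system_parameters_fields),
}

raw = {
    "opening_statement": "Hello!",
    "suggested_questions": ["What can you do?"],
    "system_parameters": {"image_file_size_limit": 10},
}
shaped = marshal(raw, parameters_fields)
print(shaped["system_parameters"]["image_file_size_limit"])  # '10' - coerced to a string by fields.String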


@ -8,28 +8,28 @@ from libs.login import login_required
from services.recommended_app_service import RecommendedAppService from services.recommended_app_service import RecommendedAppService
app_fields = { app_fields = {
"id": fields.String, 'id': fields.String,
"name": fields.String, 'name': fields.String,
"mode": fields.String, 'mode': fields.String,
"icon": fields.String, 'icon': fields.String,
"icon_background": fields.String, 'icon_background': fields.String
} }
recommended_app_fields = { recommended_app_fields = {
"app": fields.Nested(app_fields, attribute="app"), 'app': fields.Nested(app_fields, attribute='app'),
"app_id": fields.String, 'app_id': fields.String,
"description": fields.String(attribute="description"), 'description': fields.String(attribute='description'),
"copyright": fields.String, 'copyright': fields.String,
"privacy_policy": fields.String, 'privacy_policy': fields.String,
"custom_disclaimer": fields.String, 'custom_disclaimer': fields.String,
"category": fields.String, 'category': fields.String,
"position": fields.Integer, 'position': fields.Integer,
"is_listed": fields.Boolean, 'is_listed': fields.Boolean
} }
recommended_app_list_fields = { recommended_app_list_fields = {
"recommended_apps": fields.List(fields.Nested(recommended_app_fields)), 'recommended_apps': fields.List(fields.Nested(recommended_app_fields)),
"categories": fields.List(fields.String), 'categories': fields.List(fields.String)
} }
@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource):
def get(self): def get(self):
# language args # language args
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("language", type=str, location="args") parser.add_argument('language', type=str, location='args')
args = parser.parse_args() args = parser.parse_args()
if args.get("language") and args.get("language") in languages: if args.get('language') and args.get('language') in languages:
language_prefix = args.get("language") language_prefix = args.get('language')
elif current_user and current_user.interface_language: elif current_user and current_user.interface_language:
language_prefix = current_user.interface_language language_prefix = current_user.interface_language
else: else:
@ -61,5 +61,5 @@ class RecommendedAppApi(Resource):
return RecommendedAppService.get_recommend_app_detail(app_id) return RecommendedAppService.get_recommend_app_detail(app_id)
api.add_resource(RecommendedAppListApi, "/explore/apps") api.add_resource(RecommendedAppListApi, '/explore/apps')
api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>") api.add_resource(RecommendedAppApi, '/explore/apps/<uuid:app_id>')


@ -11,54 +11,56 @@ from libs.helper import TimestampField, uuid_value
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService from services.saved_message_service import SavedMessageService
feedback_fields = {"rating": fields.String} feedback_fields = {
'rating': fields.String
}
message_fields = { message_fields = {
"id": fields.String, 'id': fields.String,
"inputs": fields.Raw, 'inputs': fields.Raw,
"query": fields.String, 'query': fields.String,
"answer": fields.String, 'answer': fields.String,
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
"created_at": TimestampField, 'created_at': TimestampField
} }
class SavedMessageListApi(InstalledAppResource): class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = { saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer, 'limit': fields.Integer,
"has_more": fields.Boolean, 'has_more': fields.Boolean,
"data": fields.List(fields.Nested(message_fields)), 'data': fields.List(fields.Nested(message_fields))
} }
@marshal_with(saved_message_infinite_scroll_pagination_fields) @marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, installed_app): def get(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("last_id", type=uuid_value, location="args") parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
args = parser.parse_args() args = parser.parse_args()
return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit'])
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != "completion": if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("message_id", type=uuid_value, required=True, location="json") parser.add_argument('message_id', type=uuid_value, required=True, location='json')
args = parser.parse_args() args = parser.parse_args()
try: try:
SavedMessageService.save(app_model, current_user, args["message_id"]) SavedMessageService.save(app_model, current_user, args['message_id'])
except MessageNotExistsError: except MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
return {"result": "success"} return {'result': 'success'}
class SavedMessageApi(InstalledAppResource): class SavedMessageApi(InstalledAppResource):
@ -67,21 +69,13 @@ class SavedMessageApi(InstalledAppResource):
message_id = str(message_id) message_id = str(message_id)
if app_model.mode != "completion": if app_model.mode != 'completion':
raise NotCompletionAppError() raise NotCompletionAppError()
SavedMessageService.delete(app_model, current_user, message_id) SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"} return {'result': 'success'}
api.add_resource( api.add_resource(SavedMessageListApi, '/installed-apps/<uuid:installed_app_id>/saved-messages', endpoint='installed_app_saved_messages')
SavedMessageListApi, api.add_resource(SavedMessageApi, '/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>', endpoint='installed_app_saved_message')
"/installed-apps/<uuid:installed_app_id>/saved-messages",
endpoint="installed_app_saved_messages",
)
api.add_resource(
SavedMessageApi,
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>",
endpoint="installed_app_saved_message",
)


@ -35,13 +35,17 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
raise NotWorkflowAppError() raise NotWorkflowAppError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json')
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument('files', type=list, required=False, location='json')
args = parser.parse_args() args = parser.parse_args()
try: try:
response = AppGenerateService.generate( response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True app_model=app_model,
user=current_user,
args=args,
invoke_from=InvokeFrom.EXPLORE,
streaming=True
) )
return helper.compact_generate_response(response) return helper.compact_generate_response(response)
@ -72,10 +76,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"} return {
"result": "success"
}
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run") api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps/<uuid:installed_app_id>/workflows/run')
api.add_resource( api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop')
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
)


@ -14,33 +14,29 @@ def installed_app_required(view=None):
def decorator(view): def decorator(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
if not kwargs.get("installed_app_id"): if not kwargs.get('installed_app_id'):
raise ValueError("missing installed_app_id in path parameters") raise ValueError('missing installed_app_id in path parameters')
installed_app_id = kwargs.get("installed_app_id") installed_app_id = kwargs.get('installed_app_id')
installed_app_id = str(installed_app_id) installed_app_id = str(installed_app_id)
del kwargs["installed_app_id"] del kwargs['installed_app_id']
installed_app = ( installed_app = db.session.query(InstalledApp).filter(
db.session.query(InstalledApp) InstalledApp.id == str(installed_app_id),
.filter( InstalledApp.tenant_id == current_user.current_tenant_id
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id ).first()
)
.first()
)
if installed_app is None: if installed_app is None:
raise NotFound("Installed app not found") raise NotFound('Installed app not found')
if not installed_app.app: if not installed_app.app:
db.session.delete(installed_app) db.session.delete(installed_app)
db.session.commit() db.session.commit()
raise NotFound("Installed app not found") raise NotFound('Installed app not found')
return view(installed_app, *args, **kwargs) return view(installed_app, *args, **kwargs)
return decorated return decorated
if view: if view:

Some files were not shown because too many files have changed in this diff.