mirror of
https://github.com/langgenius/dify.git
synced 2026-02-09 04:46:38 +08:00
Compare commits
101 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 34bfb715e1 | |||
| 019d7069f8 | |||
| c54fcfb45d | |||
| cde87cb225 | |||
| 12435774ca | |||
| 80b9507e7a | |||
| 0ac0f0ffd0 | |||
| 3d14aba4b4 | |||
| 64f694865c | |||
| d36b728088 | |||
| 1a7b4c42ab | |||
| 2a64ce740e | |||
| 78988ed60e | |||
| 2832adda88 | |||
| a4e4fb4094 | |||
| 777ec64635 | |||
| 9cec8c1750 | |||
| 8ca5aa1190 | |||
| 4d8f1b9ca4 | |||
| 3da179f77b | |||
| a34e8cb0bd | |||
| b249767c5c | |||
| 89a7434565 | |||
| 3b537cbdeb | |||
| 731464f5b8 | |||
| 1ad70f8721 | |||
| 2ea8c73cd8 | |||
| f257f2c396 | |||
| 3cd8e6f5c6 | |||
| 0715db7681 | |||
| a39de8a686 | |||
| ccaf335466 | |||
| 40e36e9b52 | |||
| 9eebe9d54e | |||
| a23a191615 | |||
| 7d9c5586f9 | |||
| f07c89bba4 | |||
| 59cba930e5 | |||
| 39ae56e136 | |||
| f92130338b | |||
| 2867d29021 | |||
| f76ac8bdee | |||
| 83caffe000 | |||
| 96160837d2 | |||
| 3480f1c59e | |||
| 65ac4f69af | |||
| 2c50fab3dd | |||
| 9525ccac4f | |||
| ff76c4bd5d | |||
| 5dacf77627 | |||
| 2a213c6af7 | |||
| b2535e7db6 | |||
| 28236147ee | |||
| 4969783383 | |||
| b64080be1b | |||
| aadebd6d23 | |||
| 71cc0074ef | |||
| d77f52bf85 | |||
| b71163706b | |||
| 1fb7df12d7 | |||
| b3996b3221 | |||
| 7251748d59 | |||
| 73e9f35ab1 | |||
| d7f0056e2d | |||
| 9b7b133cbc | |||
| 7545e5de6c | |||
| a0c30702c1 | |||
| 03c988388e | |||
| 0a56c522eb | |||
| 646858ea08 | |||
| d9b821cecc | |||
| d5448e07ab | |||
| 3aa182e26a | |||
| de3b490f8e | |||
| 4481906be2 | |||
| d9f1a8ce9f | |||
| aa6d2e3035 | |||
| 4365843c20 | |||
| b4d2d635f7 | |||
| b9b28900b1 | |||
| d463b82aba | |||
| ed861ff782 | |||
| 8cc1944160 | |||
| 80e390b906 | |||
| c2acb2be60 | |||
| 8ba95c08a1 | |||
| c7de51ca9a | |||
| e02ee3bb2e | |||
| 394ceee141 | |||
| 40b48510f4 | |||
| be3b37114c | |||
| e212a87b86 | |||
| b890c11c14 | |||
| 2e27425e93 | |||
| 6269e011db | |||
| e70482dfc0 | |||
| 9b8861e3e1 | |||
| 38ca3b29b5 | |||
| 066076b157 | |||
| be27ac0e69 | |||
| 9e6d4eeb92 |
43
.github/workflows/api-tests.yml
vendored
43
.github/workflows/api-tests.yml
vendored
@ -8,6 +8,11 @@ on:
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
|
||||
@ -32,15 +37,31 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install APT packages
|
||||
uses: awalsh128/cache-apt-pkgs-action@v1
|
||||
- name: Set up Weaviate
|
||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
||||
with:
|
||||
packages: ffmpeg
|
||||
compose-file: docker/docker-compose.middleware.yaml
|
||||
services: weaviate
|
||||
|
||||
- name: Set up Python
|
||||
- name: Set up Qdrant
|
||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
||||
with:
|
||||
compose-file: docker/docker-compose.qdrant.yaml
|
||||
services: qdrant
|
||||
|
||||
- name: Set up Milvus
|
||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
||||
with:
|
||||
compose-file: docker/docker-compose.milvus.yaml
|
||||
services: |
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
./api/requirements.txt
|
||||
@ -49,11 +70,17 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: pip install -r ./api/requirements.txt -r ./api/requirements-dev.txt
|
||||
|
||||
- name: Run Unit tests
|
||||
run: dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Run ModelRuntime
|
||||
run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
|
||||
run: dev/pytest/pytest_model_runtime.sh
|
||||
|
||||
- name: Run Tool
|
||||
run: pytest api/tests/integration_tests/tools/test_all_provider.py
|
||||
run: dev/pytest/pytest_tools.sh
|
||||
|
||||
- name: Run Workflow
|
||||
run: pytest api/tests/integration_tests/workflow
|
||||
run: dev/pytest/pytest_workflow.sh
|
||||
|
||||
- name: Run Vector Stores
|
||||
run: dev/pytest/pytest_vdb.sh
|
||||
|
||||
5
.github/workflows/style.yml
vendored
5
.github/workflows/style.yml
vendored
@ -24,11 +24,14 @@ jobs:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Python dependencies
|
||||
run: pip install ruff
|
||||
run: pip install ruff dotenv-linter
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check ./api
|
||||
|
||||
- name: Dotenv check
|
||||
run: dotenv-linter ./api/.env.example ./web/.env.example
|
||||
|
||||
- name: Lint hints
|
||||
if: failure()
|
||||
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
||||
|
||||
14
README.md
14
README.md
@ -29,12 +29,12 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="./README.md"><img alt="Commits last month" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README_CN.md"><img alt="Commits last month" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README_JA.md"><img alt="Commits last month" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README_ES.md"><img alt="Commits last month" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README_KL.md"><img alt="Commits last month" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README_FR.md"><img alt="Commits last month" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
#
|
||||
@ -54,7 +54,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
|
||||
|
||||
**2. Comprehensive model support**:
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
|
||||
32
README_CN.md
32
README_CN.md
@ -44,11 +44,11 @@
|
||||
<a href="https://trendshift.io/repositories/2152" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2152" alt="langgenius%2Fdify | 趋势转变" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</div>
|
||||
|
||||
Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工作流程、RAG管道、代理功能、模型管理、可观察性功能等,让您可以快速从原型到生产。以下是其核心功能列表:
|
||||
Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI 工作流、RAG 管道、Agent、模型管理、可观测性功能等,让您可以快速从原型到生产。以下是其核心功能列表:
|
||||
</br> </br>
|
||||
|
||||
**1. 工作流**:
|
||||
在视觉画布上构建和测试功能强大的AI工作流程,利用以下所有功能以及更多功能。
|
||||
在画布上构建和测试功能强大的 AI 工作流程,利用以下所有功能以及更多功能。
|
||||
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
@ -56,7 +56,7 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工
|
||||
|
||||
|
||||
**2. 全面的模型支持**:
|
||||
与数百种专有/开源LLMs以及数十种推理提供商和自托管解决方案无缝集成,涵盖GPT、Mistral、Llama2以及任何与OpenAI API兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。
|
||||
与数百种专有/开源 LLMs 以及数十种推理提供商和自托管解决方案无缝集成,涵盖 GPT、Mistral、Llama3 以及任何与 OpenAI API 兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。
|
||||
|
||||

|
||||
|
||||
@ -65,16 +65,16 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工
|
||||
用于制作提示、比较模型性能以及向基于聊天的应用程序添加其他功能(如文本转语音)的直观界面。
|
||||
|
||||
**4. RAG Pipeline**:
|
||||
广泛的RAG功能,涵盖从文档摄入到检索的所有内容,支持从PDF、PPT和其他常见文档格式中提取文本的开箱即用的支持。
|
||||
广泛的 RAG 功能,涵盖从文档摄入到检索的所有内容,支持从 PDF、PPT 和其他常见文档格式中提取文本的开箱即用的支持。
|
||||
|
||||
**5. Agent 智能体**:
|
||||
您可以基于LLM函数调用或ReAct定义代理,并为代理添加预构建或自定义工具。Dify为AI代理提供了50多种内置工具,如谷歌搜索、DELL·E、稳定扩散和WolframAlpha等。
|
||||
您可以基于 LLM 函数调用或 ReAct 定义 Agent,并为 Agent 添加预构建或自定义工具。Dify 为 AI Agent 提供了50多种内置工具,如谷歌搜索、DELL·E、Stable Diffusion 和 WolframAlpha 等。
|
||||
|
||||
**6. LLMOps**:
|
||||
随时间监视和分析应用程序日志和性能。您可以根据生产数据和注释持续改进提示、数据集和模型。
|
||||
随时间监视和分析应用程序日志和性能。您可以根据生产数据和标注持续改进提示、数据集和模型。
|
||||
|
||||
**7. 后端即服务**:
|
||||
所有Dify的功能都带有相应的API,因此您可以轻松地将Dify集成到自己的业务逻辑中。
|
||||
所有 Dify 的功能都带有相应的 API,因此您可以轻松地将 Dify 集成到自己的业务逻辑中。
|
||||
|
||||
|
||||
## 功能比较
|
||||
@ -84,21 +84,21 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工
|
||||
<th align="center">Dify.AI</th>
|
||||
<th align="center">LangChain</th>
|
||||
<th align="center">Flowise</th>
|
||||
<th align="center">OpenAI助理API</th>
|
||||
<th align="center">OpenAI Assistant API</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">编程方法</td>
|
||||
<td align="center">API + 应用程序导向</td>
|
||||
<td align="center">Python代码</td>
|
||||
<td align="center">Python 代码</td>
|
||||
<td align="center">应用程序导向</td>
|
||||
<td align="center">API导向</td>
|
||||
<td align="center">API 导向</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">支持的LLMs</td>
|
||||
<td align="center">支持的 LLMs</td>
|
||||
<td align="center">丰富多样</td>
|
||||
<td align="center">丰富多样</td>
|
||||
<td align="center">丰富多样</td>
|
||||
<td align="center">仅限OpenAI</td>
|
||||
<td align="center">仅限 OpenAI</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">RAG引擎</td>
|
||||
@ -108,21 +108,21 @@ Dify 是一个开源的LLM应用开发平台。其直观的界面结合了AI工
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">代理</td>
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">工作流程</td>
|
||||
<td align="center">工作流</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">可观察性</td>
|
||||
<td align="center">可观测性</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
@ -202,7 +202,7 @@ docker compose up -d
|
||||
## Contributing
|
||||
|
||||
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
|
||||
同时,请考虑通过社交媒体、活动和会议来支持Dify的分享。
|
||||
同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。
|
||||
|
||||
> 我们正在寻找贡献者来帮助将Dify翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md)获取更多信息,并在我们的[Discord社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ Dify es una plataforma de desarrollo de aplicaciones de LLM de código abierto.
|
||||
|
||||
|
||||
**2. Soporte de modelos completo**:
|
||||
Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama2 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama3 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ Dify est une plateforme de développement d'applications LLM open source. Son in
|
||||
|
||||
|
||||
**2. Prise en charge complète des modèles**:
|
||||
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama2, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
|
||||
@ -55,9 +55,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
|
||||
|
||||
|
||||
**2. 網羅的なモデルサポート**:
|
||||
数百のプロプライエタリ/オープンソースのLLMと、数十の推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama2、およびOpenAI API互換のモデルをカバーします。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs
|
||||
|
||||
.dify.ai/getting-started/readme/model-providers)をご覧ください。
|
||||
数百のプロプライエタリ/オープンソースのLLMと、数十の推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama3、およびOpenAI API互換のモデルをカバーします。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs.dify.ai/getting-started/readme/model-providers)をご覧ください。
|
||||
|
||||

|
||||
|
||||
@ -155,9 +153,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
|
||||
さらなる参照や詳細な手順については、[ドキュメント](https://docs.dify.ai)をご覧ください。
|
||||
|
||||
- **エンタープライズ/組織向けのDify</br>**
|
||||
追加のエンタープライズ向け機能を提供しています。[こちらからミーティ
|
||||
|
||||
ングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)してエンタープライズのニーズについて相談してください。 </br>
|
||||
追加のエンタープライズ向け機能を提供しています。[こちらからミーティングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)してエンタープライズのニーズについて相談してください。 </br>
|
||||
> AWSを使用しているスタートアップや中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで独自のAWS VPCにデプロイできます。カスタムロゴとブランディングでアプリを作成するオプションを備えた手頃な価格のAMIオファリングです。
|
||||
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
|
||||
|
||||
**2. Comprehensive model support**:
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
|
||||
@ -52,6 +52,11 @@ AZURE_BLOB_ACCOUNT_NAME=your-account-name
|
||||
AZURE_BLOB_ACCOUNT_KEY=your-account-key
|
||||
AZURE_BLOB_CONTAINER_NAME=yout-container-name
|
||||
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
|
||||
# Aliyun oss Storage configuration
|
||||
ALIYUN_OSS_BUCKET_NAME=your-bucket-name
|
||||
ALIYUN_OSS_ACCESS_KEY=your-access-key
|
||||
ALIYUN_OSS_SECRET_KEY=your-secret-key
|
||||
ALIYUN_OSS_ENDPOINT=your-endpoint
|
||||
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
@ -160,3 +165,6 @@ CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
||||
# API Tool configuration
|
||||
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
|
||||
API_TOOL_DEFAULT_READ_TIMEOUT=60
|
||||
|
||||
# Log file path
|
||||
LOG_FILE=
|
||||
@ -55,3 +55,16 @@
|
||||
9. If you need to debug local async processing, please start the worker service by running
|
||||
`celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail`.
|
||||
The started celery app handles the async tasks, e.g. dataset importing and documents indexing.
|
||||
|
||||
|
||||
## Testing
|
||||
|
||||
1. Install dependencies for both the backend and the test environment
|
||||
```bash
|
||||
pip install -r requirements.txt -r requirements-dev.txt
|
||||
```
|
||||
|
||||
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
|
||||
```bash
|
||||
dev/pytest/pytest_all_tests.sh
|
||||
```
|
||||
|
||||
37
api/app.py
37
api/app.py
@ -1,4 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||
from gevent import monkey
|
||||
@ -17,10 +19,13 @@ import warnings
|
||||
|
||||
from flask import Flask, Response, request
|
||||
from flask_cors import CORS
|
||||
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from commands import register_commands
|
||||
from config import CloudEditionConfig, Config
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from events import event_handlers
|
||||
from extensions import (
|
||||
ext_celery,
|
||||
ext_code_based_extension,
|
||||
@ -37,11 +42,8 @@ from extensions import (
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
from libs.passport import PassportService
|
||||
from services.account_service import AccountService
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from events import event_handlers
|
||||
from models import account, dataset, model, source, task, tool, tools, web
|
||||
from services.account_service import AccountService
|
||||
|
||||
# DO NOT REMOVE ABOVE
|
||||
|
||||
@ -86,7 +88,25 @@ def create_app(test_config=None) -> Flask:
|
||||
|
||||
app.secret_key = app.config['SECRET_KEY']
|
||||
|
||||
logging.basicConfig(level=app.config.get('LOG_LEVEL', 'INFO'))
|
||||
log_handlers = None
|
||||
log_file = app.config.get('LOG_FILE')
|
||||
if log_file:
|
||||
log_dir = os.path.dirname(log_file)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_handlers = [
|
||||
RotatingFileHandler(
|
||||
filename=log_file,
|
||||
maxBytes=1024 * 1024 * 1024,
|
||||
backupCount=5
|
||||
),
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
logging.basicConfig(
|
||||
level=app.config.get('LOG_LEVEL'),
|
||||
format=app.config.get('LOG_FORMAT'),
|
||||
datefmt=app.config.get('LOG_DATEFORMAT'),
|
||||
handlers=log_handlers
|
||||
)
|
||||
|
||||
initialize_extensions(app)
|
||||
register_blueprints(app)
|
||||
@ -115,7 +135,7 @@ def initialize_extensions(app):
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint == 'console':
|
||||
if request.blueprint in ['console', 'inner_api']:
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
auth_header = request.headers.get('Authorization', '')
|
||||
if not auth_header:
|
||||
@ -151,6 +171,7 @@ def unauthorized_handler():
|
||||
def register_blueprints(app):
|
||||
from controllers.console import bp as console_app_bp
|
||||
from controllers.files import bp as files_bp
|
||||
from controllers.inner_api import bp as inner_api_bp
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.web import bp as web_bp
|
||||
|
||||
@ -188,6 +209,8 @@ def register_blueprints(app):
|
||||
)
|
||||
app.register_blueprint(files_bp)
|
||||
|
||||
app.register_blueprint(inner_api_bp)
|
||||
|
||||
|
||||
# create app
|
||||
app = create_app()
|
||||
|
||||
@ -38,6 +38,9 @@ DEFAULTS = {
|
||||
'QDRANT_CLIENT_TIMEOUT': 20,
|
||||
'CELERY_BACKEND': 'database',
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'LOG_FILE': '',
|
||||
'LOG_FORMAT': '%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s',
|
||||
'LOG_DATEFORMAT': '%Y-%m-%d %H:%M:%S',
|
||||
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
|
||||
'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
|
||||
'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
|
||||
@ -69,6 +72,8 @@ DEFAULTS = {
|
||||
'TOOL_ICON_CACHE_MAX_AGE': 3600,
|
||||
'MILVUS_DATABASE': 'default',
|
||||
'KEYWORD_DATA_SOURCE_TYPE': 'database',
|
||||
'INNER_API': 'False',
|
||||
'ENTERPRISE_ENABLED': 'False',
|
||||
}
|
||||
|
||||
|
||||
@ -99,12 +104,15 @@ class Config:
|
||||
# ------------------------
|
||||
# General Configurations.
|
||||
# ------------------------
|
||||
self.CURRENT_VERSION = "0.6.3"
|
||||
self.CURRENT_VERSION = "0.6.5"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
self.TESTING = False
|
||||
self.LOG_LEVEL = get_env('LOG_LEVEL')
|
||||
self.LOG_FILE = get_env('LOG_FILE')
|
||||
self.LOG_FORMAT = get_env('LOG_FORMAT')
|
||||
self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT')
|
||||
|
||||
# The backend URL prefix of the console API.
|
||||
# used to concatenate the login authorization callback or notion integration callback.
|
||||
@ -133,6 +141,11 @@ class Config:
|
||||
# Alternatively you can set it with `SECRET_KEY` environment variable.
|
||||
self.SECRET_KEY = get_env('SECRET_KEY')
|
||||
|
||||
# Enable or disable the inner API.
|
||||
self.INNER_API = get_bool_env('INNER_API')
|
||||
# The inner API key is used to authenticate the inner API.
|
||||
self.INNER_API_KEY = get_env('INNER_API_KEY')
|
||||
|
||||
# cors settings
|
||||
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
||||
@ -195,6 +208,10 @@ class Config:
|
||||
self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
|
||||
self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
|
||||
self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')
|
||||
self.ALIYUN_OSS_BUCKET_NAME=get_env('ALIYUN_OSS_BUCKET_NAME')
|
||||
self.ALIYUN_OSS_ACCESS_KEY=get_env('ALIYUN_OSS_ACCESS_KEY')
|
||||
self.ALIYUN_OSS_SECRET_KEY=get_env('ALIYUN_OSS_SECRET_KEY')
|
||||
self.ALIYUN_OSS_ENDPOINT=get_env('ALIYUN_OSS_ENDPOINT')
|
||||
|
||||
# ------------------------
|
||||
# Vector Store Configurations.
|
||||
@ -327,6 +344,8 @@ class Config:
|
||||
self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
|
||||
|
||||
self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
|
||||
self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
|
||||
|
||||
|
||||
@ -1,22 +1,60 @@
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint('console', __name__, url_prefix='/console/api')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
# Import other controllers
|
||||
from . import admin, apikey, extension, feature, setup, version, ping
|
||||
from . import admin, apikey, extension, feature, ping, setup, version
|
||||
|
||||
# Import app controllers
|
||||
from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message,
|
||||
model_config, site, statistic, workflow, workflow_run, workflow_app_log, workflow_statistic, agent)
|
||||
from .app import (
|
||||
advanced_prompt_template,
|
||||
agent,
|
||||
annotation,
|
||||
app,
|
||||
audio,
|
||||
completion,
|
||||
conversation,
|
||||
generator,
|
||||
message,
|
||||
model_config,
|
||||
site,
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_oauth, login, oauth
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
|
||||
|
||||
# Import enterprise controllers
|
||||
from .enterprise import enterprise_sso
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (audio, completion, conversation, installed_app, message, parameter, recommended_app,
|
||||
saved_message, workflow)
|
||||
from .explore import (
|
||||
audio,
|
||||
completion,
|
||||
conversation,
|
||||
installed_app,
|
||||
message,
|
||||
parameter,
|
||||
recommended_app,
|
||||
saved_message,
|
||||
workflow,
|
||||
)
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import account, members, model_providers, models, tool_providers, workspace
|
||||
from .workspace import account, members, model_providers, models, tool_providers, workspace
|
||||
|
||||
@ -1,26 +1,25 @@
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, BadRequest
|
||||
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, abort
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from extensions.ext_database import db
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from fields.app_fields import (
|
||||
app_detail_fields,
|
||||
app_detail_fields_with_site,
|
||||
app_pagination_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from services.app_service import AppService
|
||||
from models.model import App, AppModelConfig, AppMode
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
from services.tag_service import TagService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
|
||||
|
||||
@ -30,21 +29,29 @@ class AppListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_pagination_fields)
|
||||
def get(self):
|
||||
"""Get app list"""
|
||||
def uuid_list(value):
|
||||
try:
|
||||
return [str(uuid.UUID(v)) for v in value.split(',')]
|
||||
except ValueError:
|
||||
abort(400, message="Invalid UUID format in tag_ids.")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
|
||||
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
|
||||
parser.add_argument('name', type=str, location='args', required=False)
|
||||
parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
|
||||
if not app_pagination:
|
||||
return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
|
||||
|
||||
return app_pagination
|
||||
return marshal(app_pagination, app_pagination_fields)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -109,43 +116,9 @@ class AppApi(Resource):
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
# get original app model config
|
||||
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
model_config: AppModelConfig = app_model.app_model_config
|
||||
agent_mode = model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
)
|
||||
app_service = AppService()
|
||||
|
||||
# get decrypted parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||
masked_parameter = manager.mask_tool_parameters(parameters or {})
|
||||
else:
|
||||
masked_parameter = {}
|
||||
|
||||
# override tool parameters
|
||||
tool['tool_parameters'] = masked_parameter
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# override agent mode
|
||||
model_config.agent_mode = json.dumps(agent_mode)
|
||||
db.session.commit()
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
@ -57,6 +57,7 @@ class ModelConfigResource(Resource):
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
@ -64,6 +65,7 @@ class ModelConfigResource(Resource):
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
identity_id=f'AGENT.{app_model.id}'
|
||||
)
|
||||
except Exception as e:
|
||||
continue
|
||||
@ -94,6 +96,7 @@ class ModelConfigResource(Resource):
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -104,6 +107,7 @@ class ModelConfigResource(Resource):
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
identity_id=f'AGENT.{app_model.id}'
|
||||
)
|
||||
manager.delete_tool_parameters_cache()
|
||||
|
||||
@ -111,9 +115,11 @@ class ModelConfigResource(Resource):
|
||||
if agent_tool_entity.tool_parameters:
|
||||
if key not in masked_parameter_map:
|
||||
continue
|
||||
|
||||
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
|
||||
agent_tool_entity.tool_parameters = parameter_map[key]
|
||||
|
||||
for masked_key, masked_value in masked_parameter_map[key].items():
|
||||
if masked_key in agent_tool_entity.tool_parameters and \
|
||||
agent_tool_entity.tool_parameters[masked_key] == masked_value:
|
||||
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
|
||||
|
||||
# encrypt parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
|
||||
@ -26,10 +26,13 @@ class LoginApi(Resource):
|
||||
|
||||
try:
|
||||
account = AccountService.authenticate(args['email'], args['password'])
|
||||
except services.errors.account.AccountLoginError:
|
||||
return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401
|
||||
except services.errors.account.AccountLoginError as e:
|
||||
return {'code': 'unauthorized', 'message': str(e)}, 401
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'}
|
||||
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
|
||||
@ -48,11 +48,14 @@ class DatasetListApi(Resource):
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
ids = request.args.getlist('ids')
|
||||
provider = request.args.get('provider', default="vendor")
|
||||
search = request.args.get('keyword', default=None, type=str)
|
||||
tag_ids = request.args.getlist('tag_ids')
|
||||
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
||||
else:
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
current_user.current_tenant_id, current_user)
|
||||
current_user.current_tenant_id, current_user, search, tag_ids)
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
@ -184,6 +187,10 @@ class DatasetApi(Resource):
|
||||
help='Invalid indexing technique.')
|
||||
parser.add_argument('permission', type=str, location='json', choices=(
|
||||
'only_me', 'all_team_members'), help='Invalid permission.')
|
||||
parser.add_argument('embedding_model', type=str,
|
||||
location='json', help='Invalid embedding model.')
|
||||
parser.add_argument('embedding_model_provider', type=str,
|
||||
location='json', help='Invalid embedding model provider.')
|
||||
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -506,10 +513,27 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
else:
|
||||
raise ValueError("Unsupported vector db type.")
|
||||
|
||||
class DatasetErrorDocs(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
|
||||
|
||||
return {
|
||||
'data': [marshal(item, document_status_fields) for item in results],
|
||||
'total': len(results)
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, '/datasets')
|
||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
||||
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
||||
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
|
||||
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
|
||||
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
|
||||
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from flask import request
|
||||
@ -233,7 +234,7 @@ class DatasetDocumentListApi(Resource):
|
||||
location='json')
|
||||
parser.add_argument('data_source', type=dict, required=False, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=False, location='json')
|
||||
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
|
||||
parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json')
|
||||
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
@ -883,6 +884,49 @@ class DocumentRecoverApi(DocumentResource):
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class DocumentRetryApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('document_ids', type=list, required=True, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
retry_documents = []
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
for document_id in args['document_ids']:
|
||||
try:
|
||||
document_id = str(document_id)
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
raise NotFound("Document Not Exists.")
|
||||
|
||||
# 403 if document is archived
|
||||
if DocumentService.check_archived(document):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
|
||||
# 400 if document is completed
|
||||
if document.indexing_status == 'completed':
|
||||
raise DocumentAlreadyFinishedError()
|
||||
retry_documents.append(document)
|
||||
except Exception as e:
|
||||
logging.error(f"Document {document_id} retry failed: {str(e)}")
|
||||
continue
|
||||
# retry document
|
||||
DocumentService.retry_document(dataset_id, retry_documents)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
|
||||
api.add_resource(DatasetDocumentListApi,
|
||||
'/datasets/<uuid:dataset_id>/documents')
|
||||
@ -908,3 +952,4 @@ api.add_resource(DocumentStatusApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
|
||||
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
|
||||
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
|
||||
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
|
||||
|
||||
@ -12,7 +12,7 @@ from controllers.console.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.errors.error import (
|
||||
@ -45,10 +45,6 @@ class HitTestingApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# only high quality dataset can be used for hit testing
|
||||
if dataset.indexing_technique != 'high_quality':
|
||||
raise HighQualityDatasetOnlyError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('query', type=str, location='json')
|
||||
parser.add_argument('retrieval_model', type=dict, required=False, location='json')
|
||||
|
||||
59
api/controllers/console/enterprise/enterprise_sso.py
Normal file
59
api/controllers/console/enterprise/enterprise_sso.py
Normal file
@ -0,0 +1,59 @@
|
||||
from flask import current_app, redirect
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from services.enterprise.enterprise_sso_service import EnterpriseSSOService
|
||||
|
||||
|
||||
class EnterpriseSSOSamlLogin(Resource):
|
||||
|
||||
@setup_required
|
||||
def get(self):
|
||||
return EnterpriseSSOService.get_sso_saml_login()
|
||||
|
||||
|
||||
class EnterpriseSSOSamlAcs(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('SAMLResponse', type=str, required=True, location='form')
|
||||
args = parser.parse_args()
|
||||
saml_response = args['SAMLResponse']
|
||||
|
||||
try:
|
||||
token = EnterpriseSSOService.post_sso_saml_acs(saml_response)
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
|
||||
except Exception as e:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
|
||||
|
||||
|
||||
class EnterpriseSSOOidcLogin(Resource):
|
||||
|
||||
@setup_required
|
||||
def get(self):
|
||||
return EnterpriseSSOService.get_sso_oidc_login()
|
||||
|
||||
|
||||
class EnterpriseSSOOidcCallback(Resource):
|
||||
|
||||
@setup_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('state', type=str, required=True, location='args')
|
||||
parser.add_argument('code', type=str, required=True, location='args')
|
||||
parser.add_argument('oidc-state', type=str, required=True, location='cookies')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
token = EnterpriseSSOService.get_sso_oidc_callback(args)
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?console_token={token}')
|
||||
except Exception as e:
|
||||
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}/signin?message={str(e)}')
|
||||
|
||||
|
||||
api.add_resource(EnterpriseSSOSamlLogin, '/enterprise/sso/saml/login')
|
||||
api.add_resource(EnterpriseSSOSamlAcs, '/enterprise/sso/saml/acs')
|
||||
api.add_resource(EnterpriseSSOOidcLogin, '/enterprise/sso/oidc/login')
|
||||
api.add_resource(EnterpriseSSOOidcCallback, '/enterprise/sso/oidc/callback')
|
||||
@ -1,6 +1,7 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource
|
||||
|
||||
from services.enterprise.enterprise_feature_service import EnterpriseFeatureService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from . import api
|
||||
@ -14,4 +15,10 @@ class FeatureApi(Resource):
|
||||
return FeatureService.get_features(current_user.current_tenant_id).dict()
|
||||
|
||||
|
||||
class EnterpriseFeatureApi(Resource):
|
||||
def get(self):
|
||||
return EnterpriseFeatureService.get_enterprise_features().dict()
|
||||
|
||||
|
||||
api.add_resource(FeatureApi, '/features')
|
||||
api.add_resource(EnterpriseFeatureApi, '/enterprise-features')
|
||||
|
||||
@ -58,6 +58,8 @@ class SetupApi(Resource):
|
||||
password=args['password']
|
||||
)
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
setup()
|
||||
AccountService.update_last_login(account, request)
|
||||
|
||||
|
||||
159
api/controllers/console/tag/tags.py
Normal file
159
api/controllers/console/tag/tags.py
Normal file
@ -0,0 +1,159 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.tag_fields import tag_fields
|
||||
from libs.login import login_required
|
||||
from models.model import Tag
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError('Name must be between 1 to 50 characters.')
|
||||
return name
|
||||
|
||||
|
||||
class TagListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(tag_fields)
|
||||
def get(self):
|
||||
tag_type = request.args.get('type', type=str)
|
||||
keyword = request.args.get('keyword', default=None, type=str)
|
||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
||||
|
||||
return tags, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', nullable=False, required=True,
|
||||
help='Name must be between 1 to 50 characters.',
|
||||
type=_validate_name)
|
||||
parser.add_argument('type', type=str, location='json',
|
||||
choices=Tag.TAG_TYPE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid tag type.')
|
||||
args = parser.parse_args()
|
||||
tag = TagService.save_tags(args)
|
||||
|
||||
response = {
|
||||
'id': tag.id,
|
||||
'name': tag.name,
|
||||
'type': tag.type,
|
||||
'binding_count': 0
|
||||
}
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', nullable=False, required=True,
|
||||
help='Name must be between 1 to 50 characters.',
|
||||
type=_validate_name)
|
||||
args = parser.parse_args()
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {
|
||||
'id': tag.id,
|
||||
'name': tag.name,
|
||||
'type': tag.type,
|
||||
'binding_count': binding_count
|
||||
}
|
||||
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return 200
|
||||
|
||||
|
||||
class TagBindingCreateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
|
||||
help='Tag IDs is required.')
|
||||
parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
|
||||
help='Target ID is required.')
|
||||
parser.add_argument('type', type=str, location='json',
|
||||
choices=Tag.TAG_TYPE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid tag type.')
|
||||
args = parser.parse_args()
|
||||
TagService.save_tag_binding(args)
|
||||
|
||||
return 200
|
||||
|
||||
|
||||
class TagBindingDeleteApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('tag_id', type=str, nullable=False, required=True,
|
||||
help='Tag ID is required.')
|
||||
parser.add_argument('target_id', type=str, nullable=False, required=True,
|
||||
help='Target ID is required.')
|
||||
parser.add_argument('type', type=str, location='json',
|
||||
choices=Tag.TAG_TYPE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid tag type.')
|
||||
args = parser.parse_args()
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return 200
|
||||
|
||||
|
||||
api.add_resource(TagListApi, '/tags')
|
||||
api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
|
||||
api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
|
||||
api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
|
||||
@ -9,7 +9,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
||||
@ -43,7 +43,7 @@ class MemberInviteEmailApi(Resource):
|
||||
invitee_emails = args['emails']
|
||||
invitee_role = args['role']
|
||||
interface_language = args['language']
|
||||
if invitee_role not in ['admin', 'normal']:
|
||||
if invitee_role not in [TenantAccountRole.ADMIN, TenantAccountRole.NORMAL]:
|
||||
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
|
||||
|
||||
inviter = current_user
|
||||
|
||||
@ -11,6 +11,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from models.account import TenantAccountRole
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
@ -94,7 +95,7 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
@ -125,7 +126,7 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
@ -19,7 +20,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import login_required
|
||||
from models.account import Tenant
|
||||
from models.account import Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.file_service import FileService
|
||||
from services.workspace_service import WorkspaceService
|
||||
@ -116,6 +117,16 @@ class TenantApi(Resource):
|
||||
|
||||
tenant = current_user.current_tenant
|
||||
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
# if there is any tenant, switch to the first one
|
||||
if len(tenants) > 0:
|
||||
TenantService.switch_tenant(current_user, tenants[0].id)
|
||||
tenant = tenants[0]
|
||||
# else, raise Unauthorized
|
||||
else:
|
||||
raise Unauthorized('workspace is archived')
|
||||
|
||||
return WorkspaceService.get_tenant_info(tenant), 200
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint('files', __name__)
|
||||
|
||||
9
api/controllers/inner_api/__init__.py
Normal file
9
api/controllers/inner_api/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint('inner_api', __name__, url_prefix='/inner/api')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from .workspace import workspace
|
||||
|
||||
0
api/controllers/inner_api/workspace/__init__.py
Normal file
0
api/controllers/inner_api/workspace/__init__.py
Normal file
37
api/controllers/inner_api/workspace/workspace.py
Normal file
37
api/controllers/inner_api/workspace/workspace.py
Normal file
@ -0,0 +1,37 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.inner_api import api
|
||||
from controllers.inner_api.wraps import inner_api_only
|
||||
from events.tenant_event import tenant_was_created
|
||||
from models.account import Account
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
||||
class EnterpriseWorkspace(Resource):
|
||||
|
||||
@setup_required
|
||||
@inner_api_only
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('owner_email', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = Account.query.filter_by(email=args['owner_email']).first()
|
||||
if account is None:
|
||||
return {
|
||||
'message': 'owner account not found.'
|
||||
}, 404
|
||||
|
||||
tenant = TenantService.create_tenant(args['name'])
|
||||
TenantService.create_tenant_member(tenant, account, role='owner')
|
||||
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
return {
|
||||
'message': 'enterprise workspace created.'
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(EnterpriseWorkspace, '/enterprise/workspace')
|
||||
61
api/controllers/inner_api/wraps.py
Normal file
61
api/controllers/inner_api/wraps.py
Normal file
@ -0,0 +1,61 @@
|
||||
from base64 import b64encode
|
||||
from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
|
||||
from flask import abort, current_app, request
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def inner_api_only(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not current_app.config['INNER_API']:
|
||||
abort(404)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
inner_api_key = request.headers.get('X-Inner-Api-Key')
|
||||
if not inner_api_key or inner_api_key != current_app.config['INNER_API_KEY']:
|
||||
abort(404)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def inner_api_user_auth(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not current_app.config['INNER_API']:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
# get header 'X-Inner-Api-Key'
|
||||
authorization = request.headers.get('Authorization')
|
||||
if not authorization:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
parts = authorization.split(':')
|
||||
if len(parts) != 2:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
user_id, token = parts
|
||||
if ' ' in user_id:
|
||||
user_id = user_id.split(' ')[1]
|
||||
|
||||
inner_api_key = request.headers.get('X-Inner-Api-Key')
|
||||
|
||||
data_to_sign = f'DIFY {user_id}'
|
||||
|
||||
signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1)
|
||||
signature = b64encode(signature.digest()).decode('utf-8')
|
||||
|
||||
if signature != token:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
@ -1,5 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint('service_api', __name__, url_prefix='/v1')
|
||||
|
||||
@ -26,8 +26,11 @@ class DatasetApi(DatasetApiResource):
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
provider = request.args.get('provider', default="vendor")
|
||||
search = request.args.get('keyword', default=None, type=str)
|
||||
tag_ids = request.args.getlist('tag_ids')
|
||||
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
tenant_id, current_user)
|
||||
tenant_id, current_user, search, tag_ids)
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(
|
||||
|
||||
@ -174,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
|
||||
if not dataset:
|
||||
raise ValueError('Dataset is not exist.')
|
||||
if not dataset.indexing_technique and not args['indexing_technique']:
|
||||
if not dataset.indexing_technique and not args.get('indexing_technique'):
|
||||
raise ValueError('indexing_technique is required.')
|
||||
|
||||
# save file info
|
||||
|
||||
@ -12,7 +12,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant, TenantAccountJoin
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.model import ApiToken, App, EndUser
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -47,6 +47,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
||||
if not app_model.enable_api:
|
||||
raise NotFound()
|
||||
|
||||
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
raise NotFound()
|
||||
|
||||
kwargs['app_model'] = app_model
|
||||
|
||||
if fetch_user_arg:
|
||||
@ -137,6 +141,7 @@ def validate_dataset_token(view=None):
|
||||
.filter(Tenant.id == api_token.tenant_id) \
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
|
||||
.filter(TenantAccountJoin.role.in_(['owner'])) \
|
||||
.filter(Tenant.status == TenantStatus.NORMAL) \
|
||||
.one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint('web', __name__, url_prefix='/api')
|
||||
|
||||
@ -7,7 +7,7 @@ from controllers.web import api
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig, AppMode
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.tools import ApiToolProvider
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from werkzeug.exceptions import Forbidden
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantStatus
|
||||
from models.model import Site
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -54,6 +55,9 @@ class AppSiteApi(WebApiResource):
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
|
||||
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
|
||||
|
||||
@ -163,6 +163,7 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
tool_entity = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_config.app_id,
|
||||
agent_tool=tool,
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
@ -18,7 +18,7 @@ from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -56,6 +56,14 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
queue_manager=queue_manager,
|
||||
@ -98,7 +106,8 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
system_inputs={
|
||||
SystemVariable.QUERY: query,
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION: conversation.id,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
},
|
||||
callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
@ -84,13 +84,19 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
"""
|
||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||
|
||||
if isinstance(self._user, EndUser):
|
||||
user_id = self._user.session_id
|
||||
else:
|
||||
user_id = self._user.id
|
||||
|
||||
self._workflow = workflow
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
self._workflow_system_variables = {
|
||||
SystemVariable.QUERY: message.query,
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.CONVERSATION: conversation.id,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
}
|
||||
|
||||
self._task_state = AdvancedChatTaskState(
|
||||
|
||||
@ -26,7 +26,10 @@ class AppGenerateResponseConverter(ABC):
|
||||
else:
|
||||
def _generate():
|
||||
for chunk in cls.convert_stream_full_response(response):
|
||||
yield f'data: {chunk}\n\n'
|
||||
if chunk == 'ping':
|
||||
yield f'event: {chunk}\n\n'
|
||||
else:
|
||||
yield f'data: {chunk}\n\n'
|
||||
|
||||
return _generate()
|
||||
else:
|
||||
@ -35,7 +38,10 @@ class AppGenerateResponseConverter(ABC):
|
||||
else:
|
||||
def _generate():
|
||||
for chunk in cls.convert_stream_simple_response(response):
|
||||
yield f'data: {chunk}\n\n'
|
||||
if chunk == 'ping':
|
||||
yield f'event: {chunk}\n\n'
|
||||
else:
|
||||
yield f'data: {chunk}\n\n'
|
||||
|
||||
return _generate()
|
||||
|
||||
|
||||
@ -23,20 +23,28 @@ class BaseAppGenerator:
|
||||
value = user_inputs[variable]
|
||||
|
||||
if value:
|
||||
if not isinstance(value, str):
|
||||
if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
|
||||
raise ValueError(f"{variable} in input form must be a string")
|
||||
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
|
||||
if '.' in value:
|
||||
value = float(value)
|
||||
else:
|
||||
value = int(value)
|
||||
|
||||
if variable_config.type == VariableEntity.Type.SELECT:
|
||||
options = variable_config.options if variable_config.options is not None else []
|
||||
if value not in options:
|
||||
raise ValueError(f"{variable} in input form must be one of the following: {options}")
|
||||
else:
|
||||
elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
|
||||
if variable_config.max_length is not None:
|
||||
max_length = variable_config.max_length
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
||||
|
||||
filtered_inputs[variable] = value.replace('\x00', '') if value else None
|
||||
if value and isinstance(value, str):
|
||||
filtered_inputs[variable] = value.replace('\x00', '')
|
||||
else:
|
||||
filtered_inputs[variable] = value if value else None
|
||||
|
||||
return filtered_inputs
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -36,6 +36,14 @@ class WorkflowAppRunner:
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
@ -67,7 +75,8 @@ class WorkflowAppRunner:
|
||||
else UserFrom.END_USER,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.FILES: files
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.USER_ID: user_id
|
||||
},
|
||||
callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
@ -71,9 +71,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
"""
|
||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||
|
||||
if isinstance(self._user, EndUser):
|
||||
user_id = self._user.session_id
|
||||
else:
|
||||
user_id = self._user.id
|
||||
|
||||
self._workflow = workflow
|
||||
self._workflow_system_variables = {
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.USER_ID: user_id
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
|
||||
@ -72,7 +72,7 @@ class AppGenerateEntity(BaseModel):
|
||||
# app config
|
||||
app_config: AppConfig
|
||||
|
||||
inputs: dict[str, str]
|
||||
inputs: dict[str, Any]
|
||||
files: list[FileVar] = []
|
||||
user_id: str
|
||||
|
||||
|
||||
@ -118,7 +118,8 @@ class MessageCycleManage:
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||
|
||||
def _get_response_metadata(self) -> dict:
|
||||
"""
|
||||
|
||||
@ -84,7 +84,7 @@ class DatasetDocumentStore:
|
||||
if not isinstance(doc, Document):
|
||||
raise ValueError("doc must be a Document")
|
||||
|
||||
segment_document = self.get_document(doc_id=doc.metadata['doc_id'], raise_error=False)
|
||||
segment_document = self.get_document_segment(doc_id=doc.metadata['doc_id'])
|
||||
|
||||
# NOTE: doc could already exist in the store, but we overwrite it
|
||||
if not allow_update and segment_document:
|
||||
|
||||
@ -30,34 +30,24 @@ class CodeExecutionResponse(BaseModel):
|
||||
|
||||
class CodeExecutor:
|
||||
@classmethod
|
||||
def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
|
||||
def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], preload: str, code: str) -> str:
|
||||
"""
|
||||
Execute code
|
||||
:param language: code language
|
||||
:param code: code
|
||||
:param inputs: inputs
|
||||
:return:
|
||||
"""
|
||||
template_transformer = None
|
||||
if language == 'python3':
|
||||
template_transformer = PythonTemplateTransformer
|
||||
elif language == 'jinja2':
|
||||
template_transformer = Jinja2TemplateTransformer
|
||||
elif language == 'javascript':
|
||||
template_transformer = NodeJsTemplateTransformer
|
||||
else:
|
||||
raise CodeExecutionException('Unsupported language')
|
||||
|
||||
runner, preload = template_transformer.transform_caller(code, inputs)
|
||||
url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'
|
||||
|
||||
headers = {
|
||||
'X-Api-Key': CODE_EXECUTION_API_KEY
|
||||
}
|
||||
|
||||
data = {
|
||||
'language': 'python3' if language == 'jinja2' else
|
||||
'nodejs' if language == 'javascript' else
|
||||
'python3' if language == 'python3' else None,
|
||||
'code': runner,
|
||||
'code': code,
|
||||
'preload': preload
|
||||
}
|
||||
|
||||
@ -85,4 +75,32 @@ class CodeExecutor:
|
||||
if response.data.error:
|
||||
raise CodeExecutionException(response.data.error)
|
||||
|
||||
return template_transformer.transform_response(response.data.stdout)
|
||||
return response.data.stdout
|
||||
|
||||
@classmethod
|
||||
def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
|
||||
"""
|
||||
Execute code
|
||||
:param language: code language
|
||||
:param code: code
|
||||
:param inputs: inputs
|
||||
:return:
|
||||
"""
|
||||
template_transformer = None
|
||||
if language == 'python3':
|
||||
template_transformer = PythonTemplateTransformer
|
||||
elif language == 'jinja2':
|
||||
template_transformer = Jinja2TemplateTransformer
|
||||
elif language == 'javascript':
|
||||
template_transformer = NodeJsTemplateTransformer
|
||||
else:
|
||||
raise CodeExecutionException('Unsupported language')
|
||||
|
||||
runner, preload = template_transformer.transform_caller(code, inputs)
|
||||
|
||||
try:
|
||||
response = cls.execute_code(language, preload, runner)
|
||||
except CodeExecutionException as e:
|
||||
raise e
|
||||
|
||||
return template_transformer.transform_response(response)
|
||||
@ -1,10 +1,13 @@
|
||||
import json
|
||||
import re
|
||||
from base64 import b64encode
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
PYTHON_RUNNER = """
|
||||
import jinja2
|
||||
from json import loads
|
||||
from base64 import b64decode
|
||||
|
||||
template = jinja2.Template('''{{code}}''')
|
||||
|
||||
@ -12,7 +15,8 @@ def main(**inputs):
|
||||
return template.render(**inputs)
|
||||
|
||||
# execute main function, and return the result
|
||||
output = main(**{{inputs}})
|
||||
inputs = b64decode('{{inputs}}').decode('utf-8')
|
||||
output = main(**loads(inputs))
|
||||
|
||||
result = f'''<<RESULT>>{output}<<RESULT>>'''
|
||||
|
||||
@ -39,6 +43,7 @@ JINJA2_PRELOAD_TEMPLATE = """{% set fruits = ['Apple'] %}
|
||||
|
||||
JINJA2_PRELOAD = f"""
|
||||
import jinja2
|
||||
from base64 import b64decode
|
||||
|
||||
def _jinja2_preload_():
|
||||
# prepare jinja2 environment, load template and render before to avoid sandbox issue
|
||||
@ -60,9 +65,11 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
:return:
|
||||
"""
|
||||
|
||||
inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
|
||||
|
||||
# transform jinja2 template to python code
|
||||
runner = PYTHON_RUNNER.replace('{{code}}', code)
|
||||
runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4, ensure_ascii=False))
|
||||
runner = runner.replace('{{inputs}}', inputs_str)
|
||||
|
||||
return runner, JINJA2_PRELOAD
|
||||
|
||||
|
||||
@ -1,17 +1,22 @@
|
||||
import json
|
||||
import re
|
||||
from base64 import b64encode
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
PYTHON_RUNNER = """# declare main function here
|
||||
{{code}}
|
||||
|
||||
from json import loads, dumps
|
||||
from base64 import b64decode
|
||||
|
||||
# execute main function, and return the result
|
||||
# inputs is a dict, and it
|
||||
output = main(**{{inputs}})
|
||||
inputs = b64decode('{{inputs}}').decode('utf-8')
|
||||
output = main(**json.loads(inputs))
|
||||
|
||||
# convert output to json and print
|
||||
output = json.dumps(output, indent=4)
|
||||
output = dumps(output, indent=4)
|
||||
|
||||
result = f'''<<RESULT>>
|
||||
{output}
|
||||
@ -20,8 +25,28 @@ result = f'''<<RESULT>>
|
||||
print(result)
|
||||
"""
|
||||
|
||||
PYTHON_PRELOAD = """"""
|
||||
|
||||
PYTHON_PRELOAD = """
|
||||
# prepare general imports
|
||||
import json
|
||||
import datetime
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import binascii
|
||||
import collections
|
||||
import functools
|
||||
import operator
|
||||
import itertools
|
||||
"""
|
||||
|
||||
class PythonTemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
@ -34,7 +59,7 @@ class PythonTemplateTransformer(TemplateTransformer):
|
||||
"""
|
||||
|
||||
# transform inputs to json string
|
||||
inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)
|
||||
inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
|
||||
|
||||
# replace code and inputs
|
||||
runner = PYTHON_RUNNER.replace('{{code}}', code)
|
||||
|
||||
@ -11,12 +11,13 @@ class ToolParameterCacheType(Enum):
|
||||
|
||||
class ToolParameterCache:
|
||||
def __init__(self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
cache_type: ToolParameterCacheType
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
cache_type: ToolParameterCacheType,
|
||||
identity_id: str
|
||||
):
|
||||
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
|
||||
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""
|
||||
|
||||
@ -88,6 +88,14 @@ class PromptMessage(ABC, BaseModel):
|
||||
content: Optional[str | list[PromptMessageContent]] = None
|
||||
name: Optional[str] = None
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Check if prompt message is empty.
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
return not self.content
|
||||
|
||||
|
||||
class UserPromptMessage(PromptMessage):
|
||||
"""
|
||||
@ -118,6 +126,16 @@ class AssistantPromptMessage(PromptMessage):
|
||||
role: PromptMessageRole = PromptMessageRole.ASSISTANT
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Check if prompt message is empty.
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_calls:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
class SystemPromptMessage(PromptMessage):
|
||||
"""
|
||||
@ -132,3 +150,14 @@ class ToolPromptMessage(PromptMessage):
|
||||
"""
|
||||
role: PromptMessageRole = PromptMessageRole.TOOL
|
||||
tool_call_id: str
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
Check if prompt message is empty.
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_call_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@ -10,3 +10,6 @@
|
||||
- cohere.command-text-v14
|
||||
- meta.llama2-13b-chat-v1
|
||||
- meta.llama2-70b-chat-v1
|
||||
- mistral.mistral-large-2402-v1:0
|
||||
- mistral.mixtral-8x7b-instruct-v0:1
|
||||
- mistral.mistral-7b-instruct-v0:2
|
||||
|
||||
@ -0,0 +1,57 @@
|
||||
model: anthropic.claude-3-opus-20240229-v1:0
|
||||
label:
|
||||
en_US: Claude 3 Opus
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
# docs: https://docs.anthropic.com/claude/docs/system-prompts
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
pricing:
|
||||
input: '0.015'
|
||||
output: '0.075'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -449,6 +449,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
human_prompt_prefix = "\n[INST]"
|
||||
human_prompt_postfix = "[\\INST]\n"
|
||||
ai_prompt = ""
|
||||
|
||||
elif model_prefix == "mistral":
|
||||
human_prompt_prefix = "<s>[INST]"
|
||||
human_prompt_postfix = "[\\INST]\n"
|
||||
ai_prompt = "\n\nAssistant:"
|
||||
|
||||
elif model_prefix == "amazon":
|
||||
human_prompt_prefix = "\n\nUser:"
|
||||
@ -519,6 +524,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
|
||||
if model_parameters.get("countPenalty"):
|
||||
payload["countPenalty"] = {model_parameters.get("countPenalty")}
|
||||
|
||||
elif model_prefix == "mistral":
|
||||
payload["temperature"] = model_parameters.get("temperature")
|
||||
payload["top_p"] = model_parameters.get("top_p")
|
||||
payload["max_tokens"] = model_parameters.get("max_tokens")
|
||||
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
|
||||
payload["stop"] = stop[:10] if stop else []
|
||||
|
||||
elif model_prefix == "anthropic":
|
||||
payload = { **model_parameters }
|
||||
@ -648,6 +660,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
output = response_body.get("generation").strip('\n')
|
||||
prompt_tokens = response_body.get("prompt_token_count")
|
||||
completion_tokens = response_body.get("generation_token_count")
|
||||
|
||||
elif model_prefix == "mistral":
|
||||
output = response_body.get("outputs")[0].get("text")
|
||||
prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count')
|
||||
completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count')
|
||||
|
||||
else:
|
||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
||||
@ -731,6 +748,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
content_delta = payload.get("text")
|
||||
finish_reason = payload.get("finish_reason")
|
||||
|
||||
elif model_prefix == "mistral":
|
||||
content_delta = payload.get('outputs')[0].get("text")
|
||||
finish_reason = payload.get('outputs')[0].get("stop_reason")
|
||||
|
||||
elif model_prefix == "meta":
|
||||
content_delta = payload.get("generation").strip('\n')
|
||||
finish_reason = payload.get("stop_reason")
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
model: mistral.mistral-7b-instruct-v0:2
|
||||
label:
|
||||
en_US: Mistral 7B Instruct
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
required: false
|
||||
default: 0.9
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 50
|
||||
max: 200
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00015'
|
||||
output: '0.0002'
|
||||
unit: '0.00001'
|
||||
currency: USD
|
||||
@ -0,0 +1,27 @@
|
||||
model: mistral.mistral-large-2402-v1:0
|
||||
label:
|
||||
en_US: Mistral Large
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
default: 0.7
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
required: false
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,39 @@
|
||||
model: mistral.mixtral-8x7b-instruct-v0:1
|
||||
label:
|
||||
en_US: Mixtral 8X7B Instruct
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
required: false
|
||||
default: 0.9
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 50
|
||||
max: 200
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00045'
|
||||
output: '0.0007'
|
||||
unit: '0.00001'
|
||||
currency: USD
|
||||
@ -19,7 +19,7 @@ class GroqProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(
|
||||
model='llama2-70b-4096',
|
||||
model='llama3-8b-8192',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
||||
@ -0,0 +1,25 @@
|
||||
model: llama3-70b-8192
|
||||
label:
|
||||
zh_Hans: Llama-3-70B-8192
|
||||
en_US: Llama-3-70B-8192
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,25 @@
|
||||
model: llama3-8b-8192
|
||||
label:
|
||||
zh_Hans: Llama-3-8B-8192
|
||||
en_US: Llama-3-8B-8192
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.59'
|
||||
output: '0.79'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -1,5 +1,6 @@
|
||||
- open-mistral-7b
|
||||
- open-mixtral-8x7b
|
||||
- open-mixtral-8x22b
|
||||
- mistral-small-latest
|
||||
- mistral-medium-latest
|
||||
- mistral-large-latest
|
||||
|
||||
@ -6,6 +6,7 @@ model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
|
||||
@ -6,6 +6,7 @@ model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
|
||||
@ -6,6 +6,7 @@ model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
|
||||
@ -6,6 +6,7 @@ model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
|
||||
@ -0,0 +1,51 @@
|
||||
model: open-mixtral-8x22b
|
||||
label:
|
||||
zh_Hans: open-mixtral-8x22b
|
||||
en_US: open-mixtral-8x22b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 64000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8000
|
||||
- name: safe_prompt
|
||||
default: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.002'
|
||||
output: '0.006'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -6,6 +6,7 @@ model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
|
||||
@ -5,6 +5,9 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
|
||||
@ -5,6 +5,9 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
|
||||
@ -5,6 +5,9 @@ label:
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- multi-tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
- google/gemma-7b
|
||||
- google/codegemma-7b
|
||||
- meta/llama2-70b
|
||||
- meta/llama3-8b
|
||||
- meta/llama3-70b
|
||||
- mistralai/mixtral-8x7b-instruct-v0.1
|
||||
- fuyu-8b
|
||||
|
||||
@ -11,13 +11,19 @@ model_properties:
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0
|
||||
max: 1
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
default: 1024
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
min: -2
|
||||
|
||||
@ -22,6 +22,6 @@ parameter_rules:
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
|
||||
@ -11,13 +11,19 @@ model_properties:
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0
|
||||
max: 1
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 1024
|
||||
default: 1024
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
min: -2
|
||||
|
||||
@ -7,17 +7,23 @@ features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
context_size: 4096
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0
|
||||
max: 1
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 1024
|
||||
default: 1024
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
min: -2
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
model: meta/llama3-70b
|
||||
label:
|
||||
zh_Hans: meta/llama3-70b
|
||||
en_US: meta/llama3-70b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0
|
||||
max: 1
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
min: 1
|
||||
max: 1024
|
||||
default: 1024
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
min: -2
|
||||
max: 2
|
||||
default: 0
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
min: -2
|
||||
max: 2
|
||||
default: 0
|
||||
@ -0,0 +1,36 @@
|
||||
model: meta/llama3-8b
|
||||
label:
|
||||
zh_Hans: meta/llama3-8b
|
||||
en_US: meta/llama3-8b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0
|
||||
max: 1
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
min: 1
|
||||
max: 1024
|
||||
default: 1024
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
min: -2
|
||||
max: 2
|
||||
default: 0
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
min: -2
|
||||
max: 2
|
||||
default: 0
|
||||
@ -25,7 +25,10 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
'mistralai/mixtral-8x7b-instruct-v0.1': '',
|
||||
'google/gemma-7b': '',
|
||||
'google/codegemma-7b': '',
|
||||
'meta/llama2-70b': ''
|
||||
'meta/llama2-70b': '',
|
||||
'meta/llama3-8b': '',
|
||||
'meta/llama3-70b': ''
|
||||
|
||||
}
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
@ -131,7 +134,7 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=(10, 60)
|
||||
timeout=(10, 300)
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
@ -232,7 +235,7 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=(10, 60),
|
||||
timeout=(10, 300),
|
||||
stream=stream
|
||||
)
|
||||
|
||||
|
||||
@ -11,13 +11,19 @@ model_properties:
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0
|
||||
max: 1
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 1024
|
||||
default: 1024
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
min: -2
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
provider: nvidia
|
||||
label:
|
||||
en_US: API Catalog
|
||||
description:
|
||||
en_US: API Catalog
|
||||
zh_Hans: API Catalog
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
|
||||
@ -201,7 +201,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=(10, 60),
|
||||
timeout=(10, 300),
|
||||
stream=stream
|
||||
)
|
||||
|
||||
|
||||
@ -138,7 +138,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=(10, 60)
|
||||
timeout=(10, 300)
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
@ -154,7 +154,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
json_result['object'] = 'chat.completion'
|
||||
elif (completion_type is LLMMode.COMPLETION and json_result['object'] == ''):
|
||||
json_result['object'] = 'text_completion'
|
||||
|
||||
|
||||
if (completion_type is LLMMode.CHAT
|
||||
and ('object' not in json_result or json_result['object'] != 'chat.completion')):
|
||||
raise CredentialsValidateFailedError(
|
||||
@ -334,7 +334,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=(10, 60),
|
||||
timeout=(10, 300),
|
||||
stream=stream
|
||||
)
|
||||
|
||||
@ -425,6 +425,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
finish_reason = 'Unknown'
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
chunk = chunk.strip()
|
||||
if chunk:
|
||||
# ignore sse comments
|
||||
if chunk.startswith(':'):
|
||||
|
||||
@ -73,3 +73,22 @@ model_credential_schema:
|
||||
value: llm
|
||||
default: "4096"
|
||||
type: text-input
|
||||
- variable: vision_support
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
zh_Hans: 是否支持 Vision
|
||||
en_US: Vision Support
|
||||
type: radio
|
||||
required: false
|
||||
default: 'no_support'
|
||||
options:
|
||||
- value: 'support'
|
||||
label:
|
||||
en_US: 'Yes'
|
||||
zh_Hans: 是
|
||||
- value: 'no_support'
|
||||
label:
|
||||
en_US: 'No'
|
||||
zh_Hans: 否
|
||||
|
||||
@ -47,17 +47,8 @@ class XinferenceRerankModel(RerankModel):
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
if not isinstance(xinference_client, RESTfulRerankModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a rerank model')
|
||||
|
||||
response = xinference_client.rerank(
|
||||
handle = RESTfulRerankModelHandle(credentials['model_uid'], credentials['server_url'],auth_headers={})
|
||||
response = handle.rerank(
|
||||
documents=docs,
|
||||
query=query,
|
||||
top_n=top_n,
|
||||
@ -97,6 +88,20 @@ class XinferenceRerankModel(RerankModel):
|
||||
try:
|
||||
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
if not isinstance(xinference_client, RESTfulRerankModelHandle):
|
||||
raise InvokeBadRequestError(
|
||||
'please check model type, the model you want to invoke is not a rerank model')
|
||||
|
||||
self.invoke(
|
||||
model=model,
|
||||
@ -157,4 +162,4 @@ class XinferenceRerankModel(RerankModel):
|
||||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
||||
return entity
|
||||
|
||||
@ -47,6 +47,20 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
if not isinstance(xinference_client, RESTfulAudioModelHandle):
|
||||
raise InvokeBadRequestError(
|
||||
'please check model type, the model you want to invoke is not a audio model')
|
||||
|
||||
audio_file_path = self._get_demo_file_path()
|
||||
|
||||
with open(audio_file_path, 'rb') as audio_file:
|
||||
@ -110,17 +124,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
if not isinstance(xinference_client, RESTfulAudioModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model')
|
||||
|
||||
response = xinference_client.transcriptions(
|
||||
handle = RESTfulAudioModelHandle(credentials['model_uid'],credentials['server_url'],auth_headers={})
|
||||
response = handle.transcriptions(
|
||||
audio=file,
|
||||
language = language,
|
||||
prompt = prompt,
|
||||
|
||||
@ -47,17 +47,8 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
if server_url.endswith('/'):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
client = Client(base_url=server_url)
|
||||
|
||||
try:
|
||||
handle = client.get_model(model_uid=model_uid)
|
||||
except RuntimeError as e:
|
||||
raise InvokeAuthorizationError(e)
|
||||
|
||||
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
|
||||
|
||||
try:
|
||||
handle = RESTfulEmbeddingModelHandle(model_uid, server_url, auth_headers={})
|
||||
embeddings = handle.create_embedding(input=texts)
|
||||
except RuntimeError as e:
|
||||
raise InvokeServerUnavailableError(e)
|
||||
@ -122,6 +113,18 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
if extra_args.max_tokens:
|
||||
credentials['max_tokens'] = extra_args.max_tokens
|
||||
if server_url.endswith('/'):
|
||||
server_url = server_url[:-1]
|
||||
|
||||
client = Client(base_url=server_url)
|
||||
|
||||
try:
|
||||
handle = client.get_model(model_uid=model_uid)
|
||||
except RuntimeError as e:
|
||||
raise InvokeAuthorizationError(e)
|
||||
|
||||
if not isinstance(handle, RESTfulEmbeddingModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a text embedding model')
|
||||
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
except InvokeAuthorizationError as e:
|
||||
@ -198,4 +201,4 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
|
||||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
||||
return entity
|
||||
|
||||
@ -1,6 +1,15 @@
|
||||
|
||||
from .__version__ import __version__
|
||||
from ._client import ZhipuAI
|
||||
from .core._errors import (APIAuthenticationError, APIInternalError, APIReachLimitError, APIRequestFailedError,
|
||||
APIResponseError, APIResponseValidationError, APIServerFlowExceedError, APIStatusError,
|
||||
APITimeoutError, ZhipuAIError)
|
||||
from .core._errors import (
|
||||
APIAuthenticationError,
|
||||
APIInternalError,
|
||||
APIReachLimitError,
|
||||
APIRequestFailedError,
|
||||
APIResponseError,
|
||||
APIResponseValidationError,
|
||||
APIServerFlowExceedError,
|
||||
APIStatusError,
|
||||
APITimeoutError,
|
||||
ZhipuAIError,
|
||||
)
|
||||
|
||||
@ -31,7 +31,10 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
context: Optional[str],
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
|
||||
inputs = {key: str(value) for key, value in inputs.items()}
|
||||
|
||||
prompt_messages = []
|
||||
|
||||
model_mode = ModelMode.value_of(model_config.mode)
|
||||
@ -51,6 +54,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
query_prompt_template=query_prompt_template,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
@ -119,7 +123,8 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
context: Optional[str],
|
||||
memory_config: Optional[MemoryConfig],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
query_prompt_template: Optional[str] = None) -> list[PromptMessage]:
|
||||
"""
|
||||
Get chat model prompt messages.
|
||||
"""
|
||||
@ -146,6 +151,20 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
elif prompt_item.role == PromptMessageRole.ASSISTANT:
|
||||
prompt_messages.append(AssistantPromptMessage(content=prompt))
|
||||
|
||||
if query and query_prompt_template:
|
||||
prompt_template = PromptTemplateParser(
|
||||
template=query_prompt_template,
|
||||
with_variable_tmpl=self.with_variable_tmpl
|
||||
)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
||||
prompt_inputs['#sys.query#'] = query
|
||||
|
||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
||||
|
||||
query = prompt_template.format(
|
||||
prompt_inputs
|
||||
)
|
||||
|
||||
if memory and memory_config:
|
||||
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
||||
|
||||
|
||||
@ -40,3 +40,4 @@ class MemoryConfig(BaseModel):
|
||||
|
||||
role_prefix: Optional[RolePrefix] = None
|
||||
window: WindowConfig
|
||||
query_prompt_template: Optional[str] = None
|
||||
|
||||
@ -55,6 +55,8 @@ class SimplePromptTransform(PromptTransform):
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> \
|
||||
tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
inputs = {key: str(value) for key, value in inputs.items()}
|
||||
|
||||
model_mode = ModelMode.value_of(model_config.mode)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
prompt_messages, stops = self._get_chat_model_prompt_messages(
|
||||
|
||||
@ -110,19 +110,37 @@ class MilvusVector(BaseVector):
|
||||
return None
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
alias = uuid4().hex
|
||||
if self._client_config.secure:
|
||||
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||
else:
|
||||
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
|
||||
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
from pymilvus import utility
|
||||
if utility.has_collection(self._collection_name, using=alias):
|
||||
|
||||
ids = self.get_ids_by_metadata_field(key, value)
|
||||
if ids:
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete_by_ids(self, doc_ids: list[str]) -> None:
|
||||
alias = uuid4().hex
|
||||
if self._client_config.secure:
|
||||
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||
else:
|
||||
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
|
||||
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
|
||||
|
||||
result = self._client.query(collection_name=self._collection_name,
|
||||
filter=f'metadata["doc_id"] in {doc_ids}',
|
||||
output_fields=["id"])
|
||||
if result:
|
||||
ids = [item["id"] for item in result]
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
from pymilvus import utility
|
||||
if utility.has_collection(self._collection_name, using=alias):
|
||||
|
||||
result = self._client.query(collection_name=self._collection_name,
|
||||
filter=f'metadata["doc_id"] in {doc_ids}',
|
||||
output_fields=["id"])
|
||||
if result:
|
||||
ids = [item["id"] for item in result]
|
||||
self._client.delete(collection_name=self._collection_name, pks=ids)
|
||||
|
||||
def delete(self) -> None:
|
||||
alias = uuid4().hex
|
||||
|
||||
@ -50,7 +50,8 @@ class QdrantConfig(BaseModel):
|
||||
return {
|
||||
'url': self.endpoint,
|
||||
'api_key': self.api_key,
|
||||
'timeout': self.timeout
|
||||
'timeout': self.timeout,
|
||||
'verify': self.endpoint.startswith('https')
|
||||
}
|
||||
|
||||
|
||||
@ -217,29 +218,38 @@ class QdrantVector(BaseVector):
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
self._reload_if_needed()
|
||||
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
self._reload_if_needed()
|
||||
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def delete(self):
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
@ -257,29 +267,40 @@ class QdrantVector(BaseVector):
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
|
||||
from qdrant_client.http import models
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
|
||||
for node_id in ids:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
try:
|
||||
filter = models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="metadata.doc_id",
|
||||
match=models.MatchValue(value=node_id),
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
],
|
||||
)
|
||||
self._client.delete(
|
||||
collection_name=self._collection_name,
|
||||
points_selector=FilterSelector(
|
||||
filter=filter
|
||||
),
|
||||
)
|
||||
)
|
||||
except UnexpectedResponse as e:
|
||||
# Collection does not exist, so return
|
||||
if e.status_code == 404:
|
||||
return
|
||||
# Some other error occurred, so re-raise the exception
|
||||
else:
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
all_collection_name = []
|
||||
|
||||
@ -121,18 +121,20 @@ class WeaviateVector(BaseVector):
|
||||
return ids
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
# check whether the index already exists
|
||||
schema = self._default_schema(self._collection_name)
|
||||
if self._client.schema.contains(schema):
|
||||
where_filter = {
|
||||
"operator": "Equal",
|
||||
"path": [key],
|
||||
"valueText": value
|
||||
}
|
||||
|
||||
where_filter = {
|
||||
"operator": "Equal",
|
||||
"path": [key],
|
||||
"valueText": value
|
||||
}
|
||||
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._collection_name,
|
||||
where=where_filter,
|
||||
output='minimal'
|
||||
)
|
||||
self._client.batch.delete_objects(
|
||||
class_name=self._collection_name,
|
||||
where=where_filter,
|
||||
output='minimal'
|
||||
)
|
||||
|
||||
def delete(self):
|
||||
# check whether the index already exists
|
||||
@ -163,11 +165,14 @@ class WeaviateVector(BaseVector):
|
||||
return True
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
for uuid in ids:
|
||||
self._client.data_object.delete(
|
||||
class_name=self._collection_name,
|
||||
uuid=uuid,
|
||||
)
|
||||
# check whether the index already exists
|
||||
schema = self._default_schema(self._collection_name)
|
||||
if self._client.schema.contains(schema):
|
||||
for uuid in ids:
|
||||
self._client.data_object.delete(
|
||||
class_name=self._collection_name,
|
||||
uuid=uuid,
|
||||
)
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
"""Look up similar documents by embedding vector in Weaviate."""
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
- searxng
|
||||
- dalle
|
||||
- azuredalle
|
||||
- stability
|
||||
- wikipedia
|
||||
- model.openai
|
||||
- model.google
|
||||
@ -17,6 +18,7 @@
|
||||
- model.zhipuai
|
||||
- aippt
|
||||
- youtube
|
||||
- code
|
||||
- wolframalpha
|
||||
- maths
|
||||
- github
|
||||
|
||||
@ -44,27 +44,31 @@ class BingSearchTool(BuiltinTool):
|
||||
results = []
|
||||
if search_results:
|
||||
for result in search_results:
|
||||
url = f': {result["url"]}' if "url" in result else ""
|
||||
results.append(self.create_text_message(
|
||||
text=f'{result["name"]}: {result["url"]}'
|
||||
text=f'{result["name"]}{url}'
|
||||
))
|
||||
|
||||
|
||||
if entities:
|
||||
for entity in entities:
|
||||
url = f': {entity["url"]}' if "url" in entity else ""
|
||||
results.append(self.create_text_message(
|
||||
text=f'{entity["name"]}: {entity["url"]}'
|
||||
text=f'{entity.get("name", "")}{url}'
|
||||
))
|
||||
|
||||
if news:
|
||||
for news_item in news:
|
||||
url = f': {news_item["url"]}' if "url" in news_item else ""
|
||||
results.append(self.create_text_message(
|
||||
text=f'{news_item["name"]}: {news_item["url"]}'
|
||||
text=f'{news_item.get("name", "")}{url}'
|
||||
))
|
||||
|
||||
if related_searches:
|
||||
for related in related_searches:
|
||||
url = f': {related["displayText"]}' if "displayText" in related else ""
|
||||
results.append(self.create_text_message(
|
||||
text=f'{related["displayText"]}: {related["webSearchUrl"]}'
|
||||
text=f'{related.get("displayText", "")}{url}'
|
||||
))
|
||||
|
||||
return results
|
||||
@ -73,7 +77,7 @@ class BingSearchTool(BuiltinTool):
|
||||
text = ''
|
||||
if search_results:
|
||||
for i, result in enumerate(search_results):
|
||||
text += f'{i+1}: {result["name"]} - {result["snippet"]}\n'
|
||||
text += f'{i+1}: {result.get("name", "")} - {result.get("snippet", "")}\n'
|
||||
|
||||
if computation and 'expression' in computation and 'value' in computation:
|
||||
text += '\nComputation:\n'
|
||||
@ -82,17 +86,20 @@ class BingSearchTool(BuiltinTool):
|
||||
if entities:
|
||||
text += '\nEntities:\n'
|
||||
for entity in entities:
|
||||
text += f'{entity["name"]} - {entity["url"]}\n'
|
||||
url = f'- {entity["url"]}' if "url" in entity else ""
|
||||
text += f'{entity.get("name", "")}{url}\n'
|
||||
|
||||
if news:
|
||||
text += '\nNews:\n'
|
||||
for news_item in news:
|
||||
text += f'{news_item["name"]} - {news_item["url"]}\n'
|
||||
url = f'- {news_item["url"]}' if "url" in news_item else ""
|
||||
text += f'{news_item.get("name", "")}{url}\n'
|
||||
|
||||
if related_searches:
|
||||
text += '\n\nRelated Searches:\n'
|
||||
for related in related_searches:
|
||||
text += f'{related["displayText"]} - {related["webSearchUrl"]}\n'
|
||||
url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else ""
|
||||
text += f'{related.get("displayText", "")}{url}\n'
|
||||
|
||||
return self.create_text_message(text=self.summary(user_id=user_id, content=text))
|
||||
|
||||
|
||||
1
api/core/tools/provider/builtin/code/_assets/icon.svg
Normal file
1
api/core/tools/provider/builtin/code/_assets/icon.svg
Normal file
@ -0,0 +1 @@
|
||||
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg" class="w-3.5 h-3.5" data-icon="Code" aria-hidden="true"><g id="icons/code"><path id="Vector (Stroke)" fill-rule="evenodd" clip-rule="evenodd" d="M8.32593 1.69675C8.67754 1.78466 8.89132 2.14096 8.80342 2.49257L6.47009 11.8259C6.38218 12.1775 6.02588 12.3913 5.67427 12.3034C5.32265 12.2155 5.10887 11.8592 5.19678 11.5076L7.53011 2.17424C7.61801 1.82263 7.97431 1.60885 8.32593 1.69675ZM3.96414 4.20273C4.22042 4.45901 4.22042 4.87453 3.96413 5.13081L2.45578 6.63914C2.45577 6.63915 2.45578 6.63914 2.45578 6.63914C2.25645 6.83851 2.25643 7.16168 2.45575 7.36103C2.45574 7.36103 2.45576 7.36104 2.45575 7.36103L3.96413 8.86936C4.22041 9.12564 4.22042 9.54115 3.96414 9.79744C3.70787 10.0537 3.29235 10.0537 3.03607 9.79745L1.52769 8.28913C0.815811 7.57721 0.815803 6.42302 1.52766 5.7111L3.03606 4.20272C3.29234 3.94644 3.70786 3.94644 3.96414 4.20273ZM10.0361 4.20273C10.2923 3.94644 10.7078 3.94644 10.9641 4.20272L12.4725 5.71108C13.1843 6.423 13.1844 7.57717 12.4725 8.28909L10.9641 9.79745C10.7078 10.0537 10.2923 10.0537 10.036 9.79744C9.77977 9.54115 9.77978 9.12564 10.0361 8.86936L11.5444 7.36107C11.7437 7.16172 11.7438 6.83854 11.5444 6.63917C11.5444 6.63915 11.5445 6.63918 11.5444 6.63917L10.0361 5.13081C9.77978 4.87453 9.77978 4.45901 10.0361 4.20273Z" fill="currentColor"></path></g></svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user