mirror of
https://github.com/langgenius/dify.git
synced 2026-01-26 06:45:45 +08:00
Compare commits
205 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| eee95190cc | |||
| e8311357ff | |||
| 0f14fdd4c9 | |||
| ece0f08a2b | |||
| 5edb3d55e5 | |||
| 63382f758e | |||
| bbef964eb5 | |||
| e6db7ad1d5 | |||
| 8cc492721b | |||
| a80fe20456 | |||
| f7986805c6 | |||
| aa5ca90f00 | |||
| 4af00e4a45 | |||
| c01c95d77f | |||
| 20a9037d5b | |||
| 34d3998566 | |||
| 198d6c00d6 | |||
| 1663df8a05 | |||
| d8926a2571 | |||
| 4796f9d914 | |||
| a588df4371 | |||
| 2c1c660c6e | |||
| 13f4ed6e0e | |||
| 1e451991db | |||
| 749b236d3d | |||
| 00ce372b71 | |||
| 370e1c1a17 | |||
| 28495273b4 | |||
| 36a9c5cc6b | |||
| 228de1f12a | |||
| 01555463d2 | |||
| 6b99075dc8 | |||
| 8578ee0864 | |||
| 897e07f639 | |||
| 875249eb00 | |||
| 4d5a4e4cef | |||
| 86a6e6bd04 | |||
| 8f3042e5b3 | |||
| a1ab87107b | |||
| f49c99937c | |||
| 9b24f12bf5 | |||
| 487ce7c82a | |||
| cc835d523c | |||
| 64c3bc070a | |||
| 7405b2e819 | |||
| ca5081e327 | |||
| a79941df22 | |||
| 8137d63000 | |||
| 4aa21242b6 | |||
| 8ce93faf08 | |||
| 903ece6160 | |||
| 9f440c11e0 | |||
| 58bd5627bf | |||
| 97dcb8977a | |||
| 2fdd64c1b5 | |||
| 591b993685 | |||
| 543a00e597 | |||
| f361c7004d | |||
| bb7c62777d | |||
| d51f52a649 | |||
| e353809680 | |||
| c2f0f958ef | |||
| 087b7a6607 | |||
| 6271463240 | |||
| 6f1911533c | |||
| d5d8b98d82 | |||
| e7fe7ec0f6 | |||
| 049abd698f | |||
| 45d21677a0 | |||
| 76bec6ce7f | |||
| 6563cb6ec6 | |||
| 13cd409575 | |||
| 13292ff73e | |||
| 3f8e2456f7 | |||
| 822ee7db88 | |||
| 94a650475d | |||
| 03cf00422a | |||
| 51a9e678f0 | |||
| ad76ee76a8 | |||
| 630136b5b7 | |||
| b5f101bdac | |||
| 5940564d84 | |||
| 67902b5da7 | |||
| c0476c7881 | |||
| f68b6b0e5e | |||
| 44857702ae | |||
| b1399cd5f9 | |||
| 6f1e4a19a2 | |||
| 93393e005e | |||
| 4ea2755fce | |||
| ecb51a83d4 | |||
| 093b5c0e63 | |||
| bf42b0ae44 | |||
| 342b4fd19d | |||
| cbdb861ee4 | |||
| da5a8b9a59 | |||
| 1e6e8b446d | |||
| c1fdaa6ae0 | |||
| 142814d451 | |||
| 704755d005 | |||
| d1263700c0 | |||
| 0704fe9695 | |||
| 1d3f1d88ef | |||
| 8b3edac091 | |||
| 05cab85579 | |||
| b72fbe200d | |||
| b1194da6a5 | |||
| 338e4669e5 | |||
| c5e2659771 | |||
| 1d432728ac | |||
| 2fd702a319 | |||
| f26ad16af7 | |||
| 8f2ae51fe5 | |||
| 2f84d00300 | |||
| b82a2d97ef | |||
| 3e9dbe3e0a | |||
| 975b2fb79e | |||
| fa509ce64e | |||
| 99292edd46 | |||
| 3e992cb23c | |||
| e7b4d024ee | |||
| ff67a6d338 | |||
| 8e4989ed03 | |||
| 0940f01634 | |||
| 9d1cb1bc92 | |||
| 0ca4e30b19 | |||
| ba88f8a6f0 | |||
| aefe0cbf51 | |||
| 9ad489d133 | |||
| 661b30784e | |||
| 43a5ba9415 | |||
| 08a65d74d5 | |||
| cefe156811 | |||
| 3b5b4d628b | |||
| 8746e48df0 | |||
| 0ec8b57825 | |||
| 045827043d | |||
| 4d66a86579 | |||
| 2a8881d0e8 | |||
| ffc60bb917 | |||
| 2e454c770b | |||
| 7d711135bc | |||
| f62b2b5b45 | |||
| 7919596a21 | |||
| 9b4898efeb | |||
| 45dd1683fd | |||
| 8bca908f15 | |||
| 9cbb8ddd7f | |||
| 1be222af2e | |||
| bf9fc8fef4 | |||
| 86e7330fa2 | |||
| 34bfb715e1 | |||
| 019d7069f8 | |||
| c54fcfb45d | |||
| cde87cb225 | |||
| 12435774ca | |||
| 80b9507e7a | |||
| 0ac0f0ffd0 | |||
| 3d14aba4b4 | |||
| 64f694865c | |||
| d36b728088 | |||
| 1a7b4c42ab | |||
| 2a64ce740e | |||
| 78988ed60e | |||
| 2832adda88 | |||
| a4e4fb4094 | |||
| 777ec64635 | |||
| 9cec8c1750 | |||
| 8ca5aa1190 | |||
| 4d8f1b9ca4 | |||
| 3da179f77b | |||
| a34e8cb0bd | |||
| b249767c5c | |||
| 89a7434565 | |||
| 3b537cbdeb | |||
| 731464f5b8 | |||
| 1ad70f8721 | |||
| 2ea8c73cd8 | |||
| f257f2c396 | |||
| 3cd8e6f5c6 | |||
| 0715db7681 | |||
| a39de8a686 | |||
| ccaf335466 | |||
| 40e36e9b52 | |||
| 9eebe9d54e | |||
| a23a191615 | |||
| 7d9c5586f9 | |||
| f07c89bba4 | |||
| 59cba930e5 | |||
| 39ae56e136 | |||
| f92130338b | |||
| 2867d29021 | |||
| f76ac8bdee | |||
| 83caffe000 | |||
| 96160837d2 | |||
| 3480f1c59e | |||
| 65ac4f69af | |||
| 2c50fab3dd | |||
| 9525ccac4f | |||
| ff76c4bd5d | |||
| 5dacf77627 | |||
| 2a213c6af7 | |||
| b2535e7db6 | |||
| 28236147ee | |||
| 4969783383 |
@ -1,4 +1,4 @@
|
||||
# Devlopment with devcontainer
|
||||
# Development with devcontainer
|
||||
This project includes a devcontainer configuration that allows you to open the project in a container with a fully configured development environment.
|
||||
Both frontend and backend environments are initialized when the container is started.
|
||||
## GitHub Codespaces
|
||||
@ -33,5 +33,5 @@ Performance Impact: While usually minimal, programs running inside a devcontaine
|
||||
if you see such error message when you open this project in codespaces:
|
||||

|
||||
|
||||
a simple workaround is change `/signin` endpoint into another one, then login with github account and close the tab, then change it back to `/signin` endpoint. Then all things will be fine.
|
||||
a simple workaround is change `/signin` endpoint into another one, then login with GitHub account and close the tab, then change it back to `/signin` endpoint. Then all things will be fine.
|
||||
The reason is `signin` endpoint is not allowed in codespaces, details can be found [here](https://github.com/orgs/community/discussions/5204)
|
||||
@ -32,8 +32,8 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"postStartCommand": "cd api && pip install -r requirements.txt",
|
||||
"postCreateCommand": "cd web && npm install"
|
||||
"postStartCommand": "./.devcontainer/post_start_command.sh",
|
||||
"postCreateCommand": "./.devcontainer/post_create_command.sh"
|
||||
|
||||
// Features to add to the dev container. More info: https://containers.dev/features.
|
||||
// "features": {},
|
||||
|
||||
10
.devcontainer/post_create_command.sh
Executable file
10
.devcontainer/post_create_command.sh
Executable file
@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd web && npm install
|
||||
|
||||
echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
|
||||
echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail"' >> ~/.bashrc
|
||||
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
|
||||
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
|
||||
|
||||
source /home/vscode/.bashrc
|
||||
3
.devcontainer/post_start_command.sh
Executable file
3
.devcontainer/post_start_command.sh
Executable file
@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd api && pip install -r requirements.txt
|
||||
4
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
4
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -8,13 +8,13 @@ body:
|
||||
label: Self Checks
|
||||
description: "To make sure we get to you in time, please check the following :)"
|
||||
options:
|
||||
- label: This is only for bug report, if you would like to ask a quesion, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
|
||||
- label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
|
||||
required: true
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
|
||||
4
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
4
.github/ISSUE_TEMPLATE/document_issue.yml
vendored
@ -1,7 +1,7 @@
|
||||
name: "📚 Documentation Issue"
|
||||
description: Report issues in our documentation
|
||||
labels:
|
||||
- ducumentation
|
||||
- documentation
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
@ -12,7 +12,7 @@ body:
|
||||
required: true
|
||||
- label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
2
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
@ -12,7 +12,7 @@ body:
|
||||
required: true
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
2
.github/ISSUE_TEMPLATE/translation_issue.yml
vendored
@ -12,7 +12,7 @@ body:
|
||||
required: true
|
||||
- label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
|
||||
required: true
|
||||
- label: "Pleas do not modify this template :) and fill in all the required fields."
|
||||
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: input
|
||||
attributes:
|
||||
|
||||
61
.github/workflows/api-tests.yml
vendored
61
.github/workflows/api-tests.yml
vendored
@ -10,36 +10,14 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10", "3.11", "3.12"]
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii
|
||||
AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com
|
||||
AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94
|
||||
ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
|
||||
CHATGLM_API_BASE: http://a.abc.com:11451
|
||||
XINFERENCE_SERVER_URL: http://a.abc.com:11451
|
||||
XINFERENCE_GENERATION_MODEL_UID: generate
|
||||
XINFERENCE_CHAT_MODEL_UID: chat
|
||||
XINFERENCE_EMBEDDINGS_MODEL_UID: embedding
|
||||
XINFERENCE_RERANK_MODEL_UID: rerank
|
||||
GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz
|
||||
HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
|
||||
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a
|
||||
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b
|
||||
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c
|
||||
MOCK_SWITCH: true
|
||||
CODE_MAX_STRING_LENGTH: 80000
|
||||
python-version:
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install APT packages
|
||||
uses: awalsh128/cache-apt-pkgs-action@v1
|
||||
with:
|
||||
packages: ffmpeg
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@ -52,11 +30,44 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: pip install -r ./api/requirements.txt -r ./api/requirements-dev.txt
|
||||
|
||||
- name: Run Unit tests
|
||||
run: dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Run ModelRuntime
|
||||
run: dev/pytest/pytest_model_runtime.sh
|
||||
|
||||
- name: Run Tool
|
||||
run: dev/pytest/pytest_tools.sh
|
||||
|
||||
- name: Set up Sandbox
|
||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
sandbox
|
||||
ssrf_proxy
|
||||
|
||||
- name: Run Workflow
|
||||
run: dev/pytest/pytest_workflow.sh
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
docker/docker-compose.qdrant.yaml
|
||||
docker/docker-compose.milvus.yaml
|
||||
docker/docker-compose.pgvecto-rs.yaml
|
||||
docker/docker-compose.pgvector.yaml
|
||||
services: |
|
||||
weaviate
|
||||
qdrant
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: dev/pytest/pytest_vdb.sh
|
||||
|
||||
@ -4,7 +4,7 @@ We need to be nimble and ship fast given where we are, but we also want to make
|
||||
|
||||
This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve.
|
||||
|
||||
In terms of licensing, please take a minute to read our short [License and Contributor Agreement](./license). The community also adheres to the [code of conduct](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md).
|
||||
In terms of licensing, please take a minute to read our short [License and Contributor Agreement](./LICENSE). The community also adheres to the [code of conduct](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md).
|
||||
|
||||
## Before you jump in
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
这份指南,就像 Dify 本身一样,是一个不断改进的工作。如果有时它落后于实际项目,我们非常感谢你的理解,并欢迎任何反馈以供我们改进。
|
||||
|
||||
在许可方面,请花一分钟阅读我们简短的[许可证和贡献者协议](./license)。社区还遵守[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
|
||||
在许可方面,请花一分钟阅读我们简短的[许可证和贡献者协议](./LICENSE)。社区还遵守[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
|
||||
|
||||
## 在开始之前
|
||||
|
||||
|
||||
22
README.md
22
README.md
@ -29,19 +29,15 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="./README.md"><img alt="Commits last month" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README_CN.md"><img alt="Commits last month" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README_JA.md"><img alt="Commits last month" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README_ES.md"><img alt="Commits last month" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README_FR.md"><img alt="Commits last month" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README_KL.md"><img alt="Commits last month" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
#
|
||||
|
||||
<p align="center">
|
||||
<a href="https://trendshift.io/repositories/2152" target="_blank"><img src="https://trendshift.io/api/badge/repositories/2152" alt="langgenius%2Fdify | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</p>
|
||||
Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
|
||||
</br> </br>
|
||||
|
||||
@ -54,7 +50,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
|
||||
|
||||
**2. Comprehensive model support**:
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
@ -109,7 +105,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
@ -127,7 +123,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">Enterprise Feature (SSO/Access control)</td>
|
||||
<td align="center">Enterprise Features (SSO/Access control)</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
|
||||
@ -111,7 +111,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
|
||||
@ -54,7 +54,7 @@ Dify es una plataforma de desarrollo de aplicaciones de LLM de código abierto.
|
||||
|
||||
|
||||
**2. Soporte de modelos completo**:
|
||||
Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama2 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama3 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
@ -111,7 +111,7 @@ es basados en LLM Function Calling o ReAct, y agregar herramientas preconstruida
|
||||
<td align="center">Agente</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
|
||||
@ -54,7 +54,7 @@ Dify est une plateforme de développement d'applications LLM open source. Son in
|
||||
|
||||
|
||||
**2. Prise en charge complète des modèles**:
|
||||
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama2, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
@ -111,7 +111,7 @@ ités d'agent**:
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
|
||||
40
README_JA.md
40
README_JA.md
@ -2,7 +2,7 @@
|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify Cloud</a> ·
|
||||
<a href="https://docs.dify.ai/getting-started/install-self-hosted">自己ホスティング</a> ·
|
||||
<a href="https://docs.dify.ai/getting-started/install-self-hosted">セルフホスト</a> ·
|
||||
<a href="https://docs.dify.ai">ドキュメント</a> ·
|
||||
<a href="https://cal.com/guchenhe/dify-demo">デモのスケジュール</a>
|
||||
</p>
|
||||
@ -54,10 +54,8 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
|
||||
|
||||
|
||||
|
||||
**2. 網羅的なモデルサポート**:
|
||||
数百のプロプライエタリ/オープンソースのLLMと、数十の推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama2、およびOpenAI API互換のモデルをカバーします。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs
|
||||
|
||||
.dify.ai/getting-started/readme/model-providers)をご覧ください。
|
||||
**2. 包括的なモデルサポート**:
|
||||
数百のプロプライエタリ/オープンソースのLLMと、数十の推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama3、およびOpenAI API互換のモデルをカバーします。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs.dify.ai/getting-started/readme/model-providers)をご覧ください。
|
||||
|
||||

|
||||
|
||||
@ -96,9 +94,9 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">サポートされているLLM</td>
|
||||
<td align="center">豊富なバリエーション</td>
|
||||
<td align="center">豊富なバリエーション</td>
|
||||
<td align="center">豊富なバリエーション</td>
|
||||
<td align="center">バリエーション豊富</td>
|
||||
<td align="center">バリエーション豊富</td>
|
||||
<td align="center">バリエーション豊富</td>
|
||||
<td align="center">OpenAIのみ</td>
|
||||
</tr>
|
||||
<tr>
|
||||
@ -112,7 +110,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
|
||||
<td align="center">エージェント</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
@ -148,36 +146,34 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
|
||||
## Difyの使用方法
|
||||
|
||||
- **クラウド </br>**
|
||||
[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップが不要で誰でも試すことができます。サンドボックスプランでは、200回の無料のGPT-4呼び出しが含まれています。
|
||||
[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回の無料のGPT-4呼び出しが含まれています。
|
||||
|
||||
- **Dify Community Editionのセルフホスティング</br>**
|
||||
この[スターターガイド](#quick-start)を使用して、環境でDifyをすばやく実行できます。
|
||||
さらなる参照や詳細な手順については、[ドキュメント](https://docs.dify.ai)をご覧ください。
|
||||
この[スターターガイド](#quick-start)を使用して、ローカル環境でDifyを簡単に実行できます。
|
||||
さらなる参考資料や詳細な手順については、[ドキュメント](https://docs.dify.ai)をご覧ください。
|
||||
|
||||
- **エンタープライズ/組織向けのDify</br>**
|
||||
追加のエンタープライズ向け機能を提供しています。[こちらからミーティ
|
||||
|
||||
ングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)してエンタープライズのニーズについて相談してください。 </br>
|
||||
追加のエンタープライズ向け機能を提供しています。[こちらからミーティングを予約](https://cal.com/guchenhe/30min)したり、[メールを送信](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)してエンタープライズのニーズについて相談してください。 </br>
|
||||
> AWSを使用しているスタートアップや中小企業の場合は、[AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6)のDify Premiumをチェックして、ワンクリックで独自のAWS VPCにデプロイできます。カスタムロゴとブランディングでアプリを作成するオプションを備えた手頃な価格のAMIオファリングです。
|
||||
|
||||
|
||||
## 先を見る
|
||||
## 最新の情報を入手
|
||||
|
||||
GitHubでDifyにスターを付け、新しいリリースをすぐに通知されます。
|
||||
GitHub上でDifyにスターを付けることで、Difyに関する新しいニュースを受け取れます。
|
||||
|
||||

|
||||
|
||||
|
||||
|
||||
## クイックスタート
|
||||
> Difyをインストールする前に、マシンが以下の最小システム要件を満たしていることを確認してください:
|
||||
> Difyをインストールする前に、お使いのマシンが以下の最小システム要件を満たしていることを確認してください:
|
||||
>
|
||||
>- CPU >= 2コア
|
||||
>- RAM >= 4GB
|
||||
|
||||
</br>
|
||||
|
||||
Difyサーバーを起動する最も簡単な方法は、当社の[docker-compose.yml](docker/docker-compose.yaml)ファイルを実行することです。インストールコマンドを実行する前に、マシンに[Docker](https://docs.docker.com/get-docker/)と[Docker Compose](https://docs.docker.com/compose/install/)がインストールされていることを確認してください。
|
||||
Difyサーバーを起動する最も簡単な方法は、[docker-compose.yml](docker/docker-compose.yaml)ファイルを実行することです。インストールコマンドを実行する前に、マシンに[Docker](https://docs.docker.com/get-docker/)と[Docker Compose](https://docs.docker.com/compose/install/)がインストールされていることを確認してください。
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
@ -220,7 +216,7 @@ docker compose up -d
|
||||
* [Discord](https://discord.gg/FngNHpbcY7). 主に: アプリケーションの共有やコミュニティとの交流。
|
||||
* [Twitter](https://twitter.com/dify_ai). 主に: アプリケーションの共有やコミュニティとの交流。
|
||||
|
||||
または、直接チームメンバーとミーティングをスケジュールします:
|
||||
または、直接チームメンバーとミーティングをスケジュール:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
@ -231,7 +227,7 @@ docker compose up -d
|
||||
<td><a href='https://cal.com
|
||||
|
||||
/guchenhe/30min'>ミーティング</a></td>
|
||||
<td>無料の30分間のミーティングをスケジュールしてください。</td>
|
||||
<td>無料の30分間のミーティングをスケジュール</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><a href='mailto:support@dify.ai?subject=[GitHub]Technical%20Support'>技術サポート</a></td>
|
||||
@ -246,4 +242,4 @@ docker compose up -d
|
||||
|
||||
## ライセンス
|
||||
|
||||
プロジェクトはMITライセンスの下で利用可能です。[LICENSE](LICENSE)をご参照ください。
|
||||
このリポジトリは、Dify Open Source License にいくつかの追加制限を加えた[Difyオープンソースライセンス](LICENSE)の下で利用可能です。
|
||||
|
||||
@ -54,7 +54,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
|
||||
|
||||
**2. Comprehensive model support**:
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama2, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
@ -111,7 +111,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
|
||||
<td align="center">Agent</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
|
||||
@ -1,6 +1,3 @@
|
||||
# Server Edition
|
||||
EDITION=SELF_HOSTED
|
||||
|
||||
# Your App secret key will be used for securely signing the session cookie
|
||||
# Make sure you are changing this key for your deployment with a strong key.
|
||||
# You can generate a strong key using `openssl rand -base64 42`.
|
||||
@ -52,12 +49,23 @@ AZURE_BLOB_ACCOUNT_NAME=your-account-name
|
||||
AZURE_BLOB_ACCOUNT_KEY=your-account-key
|
||||
AZURE_BLOB_CONTAINER_NAME=yout-container-name
|
||||
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
|
||||
# Aliyun oss Storage configuration
|
||||
ALIYUN_OSS_BUCKET_NAME=your-bucket-name
|
||||
ALIYUN_OSS_ACCESS_KEY=your-access-key
|
||||
ALIYUN_OSS_SECRET_KEY=your-secret-key
|
||||
ALIYUN_OSS_ENDPOINT=your-endpoint
|
||||
ALIYUN_OSS_AUTH_VERSION=v1
|
||||
ALIYUN_OSS_REGION=your-region
|
||||
|
||||
# Google Storage configuration
|
||||
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
|
||||
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON=your-google-service-account-json-base64-string
|
||||
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, relyt
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, relyt, pgvecto_rs, pgvector
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# Weaviate configuration
|
||||
@ -70,6 +78,8 @@ WEAVIATE_BATCH_SIZE=100
|
||||
QDRANT_URL=http://localhost:6333
|
||||
QDRANT_API_KEY=difyai123456
|
||||
QDRANT_CLIENT_TIMEOUT=20
|
||||
QDRANT_GRPC_ENABLED=false
|
||||
QDRANT_GRPC_PORT=6334
|
||||
|
||||
# Milvus configuration
|
||||
MILVUS_HOST=127.0.0.1
|
||||
@ -85,6 +95,20 @@ RELYT_USER=postgres
|
||||
RELYT_PASSWORD=postgres
|
||||
RELYT_DATABASE=postgres
|
||||
|
||||
# PGVECTO_RS configuration
|
||||
PGVECTO_RS_HOST=localhost
|
||||
PGVECTO_RS_PORT=5431
|
||||
PGVECTO_RS_USER=postgres
|
||||
PGVECTO_RS_PASSWORD=difyai123456
|
||||
PGVECTO_RS_DATABASE=postgres
|
||||
|
||||
# PGVector configuration
|
||||
PGVECTOR_HOST=127.0.0.1
|
||||
PGVECTOR_PORT=5433
|
||||
PGVECTOR_USER=postgres
|
||||
PGVECTOR_PASSWORD=postgres
|
||||
PGVECTOR_DATABASE=postgres
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
@ -118,25 +142,6 @@ NOTION_CLIENT_SECRET=you-client-secret
|
||||
NOTION_CLIENT_ID=you-client-id
|
||||
NOTION_INTERNAL_SECRET=you-internal-secret
|
||||
|
||||
# Hosted Model Credentials
|
||||
HOSTED_OPENAI_API_KEY=
|
||||
HOSTED_OPENAI_API_BASE=
|
||||
HOSTED_OPENAI_API_ORGANIZATION=
|
||||
HOSTED_OPENAI_TRIAL_ENABLED=false
|
||||
HOSTED_OPENAI_QUOTA_LIMIT=200
|
||||
HOSTED_OPENAI_PAID_ENABLED=false
|
||||
|
||||
HOSTED_AZURE_OPENAI_ENABLED=false
|
||||
HOSTED_AZURE_OPENAI_API_KEY=
|
||||
HOSTED_AZURE_OPENAI_API_BASE=
|
||||
HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
|
||||
|
||||
HOSTED_ANTHROPIC_API_BASE=
|
||||
HOSTED_ANTHROPIC_API_KEY=
|
||||
HOSTED_ANTHROPIC_TRIAL_ENABLED=false
|
||||
HOSTED_ANTHROPIC_QUOTA_LIMIT=600000
|
||||
HOSTED_ANTHROPIC_PAID_ENABLED=false
|
||||
|
||||
ETL_TYPE=dify
|
||||
UNSTRUCTURED_API_URL=
|
||||
|
||||
@ -160,3 +165,13 @@ CODE_MAX_NUMBER_ARRAY_LENGTH=1000
|
||||
# API Tool configuration
|
||||
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
|
||||
API_TOOL_DEFAULT_READ_TIMEOUT=60
|
||||
|
||||
# HTTP Node configuration
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT=300
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT=600
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT=600
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 # 10MB
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 # 1MB
|
||||
|
||||
# Log file path
|
||||
LOG_FILE=
|
||||
|
||||
20
api/app.py
20
api/app.py
@ -1,28 +1,28 @@
|
||||
import os
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||
from gevent import monkey
|
||||
|
||||
monkey.patch_all()
|
||||
# if os.environ.get("VECTOR_STORE") == 'milvus':
|
||||
|
||||
import grpc.experimental.gevent
|
||||
|
||||
grpc.experimental.gevent.init_gevent()
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
from flask import Flask, Response, request
|
||||
from flask_cors import CORS
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from commands import register_commands
|
||||
from config import CloudEditionConfig, Config
|
||||
from config import Config
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from events import event_handlers
|
||||
@ -75,16 +75,9 @@ config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first
|
||||
# ----------------------------
|
||||
|
||||
|
||||
def create_app(test_config=None) -> Flask:
|
||||
def create_app() -> Flask:
|
||||
app = DifyApp(__name__)
|
||||
|
||||
if test_config:
|
||||
app.config.from_object(test_config)
|
||||
else:
|
||||
if config_type == "CLOUD":
|
||||
app.config.from_object(CloudEditionConfig())
|
||||
else:
|
||||
app.config.from_object(Config())
|
||||
app.config.from_object(Config())
|
||||
|
||||
app.secret_key = app.config['SECRET_KEY']
|
||||
|
||||
@ -101,6 +94,7 @@ def create_app(test_config=None) -> Flask:
|
||||
),
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
|
||||
logging.basicConfig(
|
||||
level=app.config.get('LOG_LEVEL'),
|
||||
format=app.config.get('LOG_FORMAT'),
|
||||
|
||||
@ -305,6 +305,14 @@ def migrate_knowledge_vector_database():
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == "pgvector":
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": 'pgvector',
|
||||
"vector_store": {"class_prefix": collection_name}
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ import dotenv
|
||||
dotenv.load_dotenv()
|
||||
|
||||
DEFAULTS = {
|
||||
'EDITION': 'SELF_HOSTED',
|
||||
'DB_USERNAME': 'postgres',
|
||||
'DB_PASSWORD': '',
|
||||
'DB_HOST': 'localhost',
|
||||
@ -36,6 +37,8 @@ DEFAULTS = {
|
||||
'WEAVIATE_GRPC_ENABLED': 'True',
|
||||
'WEAVIATE_BATCH_SIZE': 100,
|
||||
'QDRANT_CLIENT_TIMEOUT': 20,
|
||||
'QDRANT_GRPC_ENABLED': 'False',
|
||||
'QDRANT_GRPC_PORT': '6334',
|
||||
'CELERY_BACKEND': 'database',
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'LOG_FILE': '',
|
||||
@ -104,9 +107,9 @@ class Config:
|
||||
# ------------------------
|
||||
# General Configurations.
|
||||
# ------------------------
|
||||
self.CURRENT_VERSION = "0.6.4"
|
||||
self.CURRENT_VERSION = "0.6.8"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.EDITION = get_env('EDITION')
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
self.TESTING = False
|
||||
self.LOG_LEVEL = get_env('LOG_LEVEL')
|
||||
@ -208,10 +211,18 @@ class Config:
|
||||
self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
|
||||
self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
|
||||
self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')
|
||||
self.ALIYUN_OSS_BUCKET_NAME=get_env('ALIYUN_OSS_BUCKET_NAME')
|
||||
self.ALIYUN_OSS_ACCESS_KEY=get_env('ALIYUN_OSS_ACCESS_KEY')
|
||||
self.ALIYUN_OSS_SECRET_KEY=get_env('ALIYUN_OSS_SECRET_KEY')
|
||||
self.ALIYUN_OSS_ENDPOINT=get_env('ALIYUN_OSS_ENDPOINT')
|
||||
self.ALIYUN_OSS_REGION=get_env('ALIYUN_OSS_REGION')
|
||||
self.ALIYUN_OSS_AUTH_VERSION=get_env('ALIYUN_OSS_AUTH_VERSION')
|
||||
self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME')
|
||||
self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')
|
||||
|
||||
# ------------------------
|
||||
# Vector Store Configurations.
|
||||
# Currently, only support: qdrant, milvus, zilliz, weaviate, relyt
|
||||
# Currently, only support: qdrant, milvus, zilliz, weaviate, relyt, pgvector
|
||||
# ------------------------
|
||||
self.VECTOR_STORE = get_env('VECTOR_STORE')
|
||||
self.KEYWORD_STORE = get_env('KEYWORD_STORE')
|
||||
@ -219,6 +230,8 @@ class Config:
|
||||
self.QDRANT_URL = get_env('QDRANT_URL')
|
||||
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
|
||||
self.QDRANT_CLIENT_TIMEOUT = get_env('QDRANT_CLIENT_TIMEOUT')
|
||||
self.QDRANT_GRPC_ENABLED = get_env('QDRANT_GRPC_ENABLED')
|
||||
self.QDRANT_GRPC_PORT = get_env('QDRANT_GRPC_PORT')
|
||||
|
||||
# milvus / zilliz setting
|
||||
self.MILVUS_HOST = get_env('MILVUS_HOST')
|
||||
@ -241,6 +254,20 @@ class Config:
|
||||
self.RELYT_PASSWORD = get_env('RELYT_PASSWORD')
|
||||
self.RELYT_DATABASE = get_env('RELYT_DATABASE')
|
||||
|
||||
# pgvecto rs settings
|
||||
self.PGVECTO_RS_HOST = get_env('PGVECTO_RS_HOST')
|
||||
self.PGVECTO_RS_PORT = get_env('PGVECTO_RS_PORT')
|
||||
self.PGVECTO_RS_USER = get_env('PGVECTO_RS_USER')
|
||||
self.PGVECTO_RS_PASSWORD = get_env('PGVECTO_RS_PASSWORD')
|
||||
self.PGVECTO_RS_DATABASE = get_env('PGVECTO_RS_DATABASE')
|
||||
|
||||
# pgvector settings
|
||||
self.PGVECTOR_HOST = get_env('PGVECTOR_HOST')
|
||||
self.PGVECTOR_PORT = get_env('PGVECTOR_PORT')
|
||||
self.PGVECTOR_USER = get_env('PGVECTOR_USER')
|
||||
self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
|
||||
self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')
|
||||
|
||||
# ------------------------
|
||||
# Mail Configurations.
|
||||
# ------------------------
|
||||
@ -256,7 +283,7 @@ class Config:
|
||||
self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
|
||||
|
||||
# ------------------------
|
||||
# Workpace Configurations.
|
||||
# Workspace Configurations.
|
||||
# ------------------------
|
||||
self.INVITE_EXPIRY_HOURS = int(get_env('INVITE_EXPIRY_HOURS'))
|
||||
|
||||
@ -295,6 +322,12 @@ class Config:
|
||||
# ------------------------
|
||||
# Platform Configurations.
|
||||
# ------------------------
|
||||
self.GITHUB_CLIENT_ID = get_env('GITHUB_CLIENT_ID')
|
||||
self.GITHUB_CLIENT_SECRET = get_env('GITHUB_CLIENT_SECRET')
|
||||
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
|
||||
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
|
||||
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
|
||||
|
||||
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
|
||||
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
|
||||
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
|
||||
@ -341,17 +374,3 @@ class Config:
|
||||
|
||||
self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
|
||||
self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.EDITION = "CLOUD"
|
||||
|
||||
self.GITHUB_CLIENT_ID = get_env('GITHUB_CLIENT_ID')
|
||||
self.GITHUB_CLIENT_SECRET = get_env('GITHUB_CLIENT_SECRET')
|
||||
self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID')
|
||||
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
|
||||
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
|
||||
|
||||
languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN']
|
||||
languages = ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA', 'vi-VN', 'pl-PL']
|
||||
|
||||
language_timezone_mapping = {
|
||||
'en-US': 'America/New_York',
|
||||
'zh-Hans': 'Asia/Shanghai',
|
||||
'zh-Hant': 'Asia/Taipei',
|
||||
'pt-BR': 'America/Sao_Paulo',
|
||||
'es-ES': 'Europe/Madrid',
|
||||
'fr-FR': 'Europe/Paris',
|
||||
@ -15,6 +16,7 @@ language_timezone_mapping = {
|
||||
'it-IT': 'Europe/Rome',
|
||||
'uk-UA': 'Europe/Kyiv',
|
||||
'vi-VN': 'Asia/Ho_Chi_Minh',
|
||||
'pl-PL': 'Europe/Warsaw',
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -53,5 +53,8 @@ from .explore import (
|
||||
workflow,
|
||||
)
|
||||
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
|
||||
# Import workspace controllers
|
||||
from .workspace import account, members, model_providers, models, tool_providers, workspace
|
||||
|
||||
@ -1,17 +1,16 @@
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, abort
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import (
|
||||
app_detail_fields,
|
||||
app_detail_fields_with_site,
|
||||
@ -20,6 +19,7 @@ from fields.app_fields import (
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from services.app_service import AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
|
||||
|
||||
@ -29,21 +29,29 @@ class AppListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_pagination_fields)
|
||||
def get(self):
|
||||
"""Get app list"""
|
||||
def uuid_list(value):
|
||||
try:
|
||||
return [str(uuid.UUID(v)) for v in value.split(',')]
|
||||
except ValueError:
|
||||
abort(400, message="Invalid UUID format in tag_ids.")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args')
|
||||
parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args')
|
||||
parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False)
|
||||
parser.add_argument('name', type=str, location='args', required=False)
|
||||
parser.add_argument('tag_ids', type=uuid_list, location='args', required=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args)
|
||||
if not app_pagination:
|
||||
return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False}
|
||||
|
||||
return app_pagination
|
||||
return marshal(app_pagination, app_pagination_fields)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -108,43 +116,9 @@ class AppApi(Resource):
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
# get original app model config
|
||||
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
model_config: AppModelConfig = app_model.app_model_config
|
||||
agent_mode = model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
)
|
||||
app_service = AppService()
|
||||
|
||||
# get decrypted parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||
masked_parameter = manager.mask_tool_parameters(parameters or {})
|
||||
else:
|
||||
masked_parameter = {}
|
||||
|
||||
# override tool parameters
|
||||
tool['tool_parameters'] = masked_parameter
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# override agent mode
|
||||
model_config.agent_mode = json.dumps(agent_mode)
|
||||
db.session.commit()
|
||||
app_model = app_service.get_app(app_model)
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
@ -91,3 +91,9 @@ class DraftWorkflowNotExist(BaseHTTPException):
|
||||
error_code = 'draft_workflow_not_exist'
|
||||
description = "Draft workflow need to be initialized."
|
||||
code = 400
|
||||
|
||||
|
||||
class DraftWorkflowNotSync(BaseHTTPException):
|
||||
error_code = 'draft_workflow_not_sync'
|
||||
description = "Workflow graph might have been modified, please refresh and resubmit."
|
||||
code = 400
|
||||
|
||||
@ -57,6 +57,7 @@ class ModelConfigResource(Resource):
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
@ -64,6 +65,7 @@ class ModelConfigResource(Resource):
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
identity_id=f'AGENT.{app_model.id}'
|
||||
)
|
||||
except Exception as e:
|
||||
continue
|
||||
@ -94,6 +96,7 @@ class ModelConfigResource(Resource):
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -104,6 +107,7 @@ class ModelConfigResource(Resource):
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
identity_id=f'AGENT.{app_model.id}'
|
||||
)
|
||||
manager.delete_tool_parameters_cache()
|
||||
|
||||
@ -111,9 +115,11 @@ class ModelConfigResource(Resource):
|
||||
if agent_tool_entity.tool_parameters:
|
||||
if key not in masked_parameter_map:
|
||||
continue
|
||||
|
||||
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
|
||||
agent_tool_entity.tool_parameters = parameter_map[key]
|
||||
|
||||
for masked_key, masked_value in masked_parameter_map[key].items():
|
||||
if masked_key in agent_tool_entity.tool_parameters and \
|
||||
agent_tool_entity.tool_parameters[masked_key] == masked_value:
|
||||
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)
|
||||
|
||||
# encrypt parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
|
||||
@ -7,7 +7,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
@ -20,6 +20,7 @@ from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -59,6 +60,7 @@ class DraftWorkflowApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
|
||||
parser.add_argument('hash', type=str, required=False, location='json')
|
||||
args = parser.parse_args()
|
||||
elif 'text/plain' in content_type:
|
||||
try:
|
||||
@ -71,7 +73,8 @@ class DraftWorkflowApi(Resource):
|
||||
|
||||
args = {
|
||||
'graph': data.get('graph'),
|
||||
'features': data.get('features')
|
||||
'features': data.get('features'),
|
||||
'hash': data.get('hash')
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {'message': 'Invalid JSON data'}, 400
|
||||
@ -79,15 +82,21 @@ class DraftWorkflowApi(Resource):
|
||||
abort(415)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=args.get('graph'),
|
||||
features=args.get('features'),
|
||||
account=current_user
|
||||
)
|
||||
|
||||
try:
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=args.get('graph'),
|
||||
features=args.get('features'),
|
||||
unique_hash=args.get('hash'),
|
||||
account=current_user
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
|
||||
}
|
||||
|
||||
|
||||
@ -227,7 +227,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
{{start}}
|
||||
{{end}}
|
||||
GROUP BY date, c.created_by) sub
|
||||
GROUP BY sub.created_by, sub.date
|
||||
GROUP BY sub.date
|
||||
"""
|
||||
arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value}
|
||||
|
||||
|
||||
@ -48,11 +48,14 @@ class DatasetListApi(Resource):
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
ids = request.args.getlist('ids')
|
||||
provider = request.args.get('provider', default="vendor")
|
||||
search = request.args.get('keyword', default=None, type=str)
|
||||
tag_ids = request.args.getlist('tag_ids')
|
||||
|
||||
if ids:
|
||||
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
|
||||
else:
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
current_user.current_tenant_id, current_user)
|
||||
current_user.current_tenant_id, current_user, search, tag_ids)
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
@ -184,6 +187,10 @@ class DatasetApi(Resource):
|
||||
help='Invalid indexing technique.')
|
||||
parser.add_argument('permission', type=str, location='json', choices=(
|
||||
'only_me', 'all_team_members'), help='Invalid permission.')
|
||||
parser.add_argument('embedding_model', type=str,
|
||||
location='json', help='Invalid embedding model.')
|
||||
parser.add_argument('embedding_model_provider', type=str,
|
||||
location='json', help='Invalid embedding model provider.')
|
||||
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -469,13 +476,13 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = current_app.config['VECTOR_STORE']
|
||||
if vector_type == 'milvus':
|
||||
if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search'
|
||||
]
|
||||
}
|
||||
elif vector_type == 'qdrant' or vector_type == 'weaviate':
|
||||
elif vector_type in {"qdrant", "weaviate"}:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search', 'full_text_search', 'hybrid_search'
|
||||
@ -490,14 +497,13 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
|
||||
if vector_type == 'milvus':
|
||||
if vector_type in {'milvus', 'relyt', 'pgvector'}:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search'
|
||||
]
|
||||
}
|
||||
elif vector_type == 'qdrant' or vector_type == 'weaviate':
|
||||
elif vector_type in {'qdrant', 'weaviate'}:
|
||||
return {
|
||||
'retrieval_method': [
|
||||
'semantic_search', 'full_text_search', 'hybrid_search'
|
||||
@ -506,10 +512,27 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
else:
|
||||
raise ValueError("Unsupported vector db type.")
|
||||
|
||||
class DatasetErrorDocs(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
|
||||
|
||||
return {
|
||||
'data': [marshal(item, document_status_fields) for item in results],
|
||||
'total': len(results)
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, '/datasets')
|
||||
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
|
||||
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
|
||||
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
|
||||
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
|
||||
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
|
||||
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from flask import request
|
||||
@ -233,7 +234,7 @@ class DatasetDocumentListApi(Resource):
|
||||
location='json')
|
||||
parser.add_argument('data_source', type=dict, required=False, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=False, location='json')
|
||||
parser.add_argument('duplicate', type=bool, nullable=False, location='json')
|
||||
parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json')
|
||||
parser.add_argument('original_document_id', type=str, required=False, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
@ -393,9 +394,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
def get(self, dataset_id, batch):
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
documents = self.get_batch_documents(dataset_id, batch)
|
||||
response = {
|
||||
"tokens": 0,
|
||||
@ -883,6 +881,49 @@ class DocumentRecoverApi(DocumentResource):
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
class DocumentRetryApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('document_ids', type=list, required=True, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
retry_documents = []
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
for document_id in args['document_ids']:
|
||||
try:
|
||||
document_id = str(document_id)
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
raise NotFound("Document Not Exists.")
|
||||
|
||||
# 403 if document is archived
|
||||
if DocumentService.check_archived(document):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
|
||||
# 400 if document is completed
|
||||
if document.indexing_status == 'completed':
|
||||
raise DocumentAlreadyFinishedError()
|
||||
retry_documents.append(document)
|
||||
except Exception as e:
|
||||
logging.error(f"Document {document_id} retry failed: {str(e)}")
|
||||
continue
|
||||
# retry document
|
||||
DocumentService.retry_document(dataset_id, retry_documents)
|
||||
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
|
||||
api.add_resource(DatasetDocumentListApi,
|
||||
'/datasets/<uuid:dataset_id>/documents')
|
||||
@ -908,3 +949,4 @@ api.add_resource(DocumentStatusApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>')
|
||||
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
|
||||
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
|
||||
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
|
||||
|
||||
159
api/controllers/console/tag/tags.py
Normal file
159
api/controllers/console/tag/tags.py
Normal file
@ -0,0 +1,159 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.tag_fields import tag_fields
|
||||
from libs.login import login_required
|
||||
from models.model import Tag
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError('Name must be between 1 to 50 characters.')
|
||||
return name
|
||||
|
||||
|
||||
class TagListApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(tag_fields)
|
||||
def get(self):
|
||||
tag_type = request.args.get('type', type=str)
|
||||
keyword = request.args.get('keyword', default=None, type=str)
|
||||
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)
|
||||
|
||||
return tags, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', nullable=False, required=True,
|
||||
help='Name must be between 1 to 50 characters.',
|
||||
type=_validate_name)
|
||||
parser.add_argument('type', type=str, location='json',
|
||||
choices=Tag.TAG_TYPE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid tag type.')
|
||||
args = parser.parse_args()
|
||||
tag = TagService.save_tags(args)
|
||||
|
||||
response = {
|
||||
'id': tag.id,
|
||||
'name': tag.name,
|
||||
'type': tag.type,
|
||||
'binding_count': 0
|
||||
}
|
||||
|
||||
return response, 200
|
||||
|
||||
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', nullable=False, required=True,
|
||||
help='Name must be between 1 to 50 characters.',
|
||||
type=_validate_name)
|
||||
args = parser.parse_args()
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = {
|
||||
'id': tag.id,
|
||||
'name': tag.name,
|
||||
'type': tag.type,
|
||||
'binding_count': binding_count
|
||||
}
|
||||
|
||||
return response, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return 200
|
||||
|
||||
|
||||
class TagBindingCreateApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json',
|
||||
help='Tag IDs is required.')
|
||||
parser.add_argument('target_id', type=str, nullable=False, required=True, location='json',
|
||||
help='Target ID is required.')
|
||||
parser.add_argument('type', type=str, location='json',
|
||||
choices=Tag.TAG_TYPE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid tag type.')
|
||||
args = parser.parse_args()
|
||||
TagService.save_tag_binding(args)
|
||||
|
||||
return 200
|
||||
|
||||
|
||||
class TagBindingDeleteApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('tag_id', type=str, nullable=False, required=True,
|
||||
help='Tag ID is required.')
|
||||
parser.add_argument('target_id', type=str, nullable=False, required=True,
|
||||
help='Target ID is required.')
|
||||
parser.add_argument('type', type=str, location='json',
|
||||
choices=Tag.TAG_TYPE_LIST,
|
||||
nullable=True,
|
||||
help='Invalid tag type.')
|
||||
args = parser.parse_args()
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return 200
|
||||
|
||||
|
||||
api.add_resource(TagListApi, '/tags')
|
||||
api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>')
|
||||
api.add_resource(TagBindingCreateApi, '/tag-bindings/create')
|
||||
api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove')
|
||||
@ -9,7 +9,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_with_role_list_fields
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
||||
@ -43,7 +43,7 @@ class MemberInviteEmailApi(Resource):
|
||||
invitee_emails = args['emails']
|
||||
invitee_role = args['role']
|
||||
interface_language = args['language']
|
||||
if invitee_role not in ['admin', 'normal']:
|
||||
if invitee_role not in [TenantAccountRole.ADMIN, TenantAccountRole.NORMAL]:
|
||||
return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
|
||||
|
||||
inviter = current_user
|
||||
|
||||
@ -11,6 +11,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from models.account import TenantAccountRole
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
@ -94,7 +95,7 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
@ -125,7 +126,7 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
@ -26,8 +26,11 @@ class DatasetApi(DatasetApiResource):
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
provider = request.args.get('provider', default="vendor")
|
||||
search = request.args.get('keyword', default=None, type=str)
|
||||
tag_ids = request.args.getlist('tag_ids')
|
||||
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
tenant_id, current_user)
|
||||
tenant_id, current_user, search, tag_ids)
|
||||
# check embedding setting
|
||||
provider_manager = ProviderManager()
|
||||
configurations = provider_manager.get_configurations(
|
||||
|
||||
@ -163,6 +163,7 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
tool_entity = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_config.app_id,
|
||||
agent_tool=tool,
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
@ -219,7 +219,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
|
||||
thought=scratchpad.thought,
|
||||
observation={scratchpad.action.action_name: tool_invoke_response},
|
||||
tool_invoke_meta=tool_invoke_meta.to_dict(),
|
||||
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
|
||||
answer=scratchpad.agent_response,
|
||||
messages_ids=message_file_ids,
|
||||
llm_usage=usage_dict['usage']
|
||||
|
||||
@ -18,7 +18,7 @@ from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -56,6 +56,14 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
queue_manager=queue_manager,
|
||||
@ -98,7 +106,8 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
system_inputs={
|
||||
SystemVariable.QUERY: query,
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.CONVERSATION: conversation.id,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
},
|
||||
callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
@ -8,6 +8,8 @@ from core.app.entities.task_entities import (
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
)
|
||||
|
||||
@ -111,6 +113,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
|
||||
|
||||
@ -28,9 +28,9 @@ from core.app.entities.task_entities import (
|
||||
AdvancedChatTaskState,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ChatflowStreamGenerateRoute,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamGenerateRoute,
|
||||
StreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
@ -84,13 +84,19 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
"""
|
||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||
|
||||
if isinstance(self._user, EndUser):
|
||||
user_id = self._user.session_id
|
||||
else:
|
||||
user_id = self._user.id
|
||||
|
||||
self._workflow = workflow
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
self._workflow_system_variables = {
|
||||
SystemVariable.QUERY: message.query,
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.CONVERSATION: conversation.id,
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
}
|
||||
|
||||
self._task_state = AdvancedChatTaskState(
|
||||
@ -337,7 +343,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
**extras
|
||||
)
|
||||
|
||||
def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]:
|
||||
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
@ -360,7 +366,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
continue
|
||||
|
||||
for start_node_id in start_node_ids:
|
||||
stream_generate_routes[start_node_id] = StreamGenerateRoute(
|
||||
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
|
||||
answer_node_id=answer_node_id,
|
||||
generate_route=generate_route
|
||||
)
|
||||
@ -424,15 +430,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == 'text':
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
for token in route_chunk.text:
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(token)
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
self._task_state.answer += token
|
||||
yield self._message_to_stream_response(token, self._message.id)
|
||||
time.sleep(0.01)
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
self._task_state.answer += route_chunk.text
|
||||
yield self._message_to_stream_response(route_chunk.text, self._message.id)
|
||||
else:
|
||||
break
|
||||
|
||||
@ -457,10 +462,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == 'text':
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
for token in route_chunk.text:
|
||||
self._task_state.answer += token
|
||||
yield self._message_to_stream_response(token, self._message.id)
|
||||
time.sleep(0.01)
|
||||
self._task_state.answer += route_chunk.text
|
||||
yield self._message_to_stream_response(route_chunk.text, self._message.id)
|
||||
else:
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
|
||||
@ -13,7 +13,9 @@ class BaseAppGenerator:
|
||||
for variable_config in variables:
|
||||
variable = variable_config.variable
|
||||
|
||||
if variable not in user_inputs or not user_inputs[variable]:
|
||||
if (variable not in user_inputs
|
||||
or user_inputs[variable] is None
|
||||
or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')):
|
||||
if variable_config.required:
|
||||
raise ValueError(f"{variable} is required in input form")
|
||||
else:
|
||||
@ -22,21 +24,29 @@ class BaseAppGenerator:
|
||||
|
||||
value = user_inputs[variable]
|
||||
|
||||
if value:
|
||||
if not isinstance(value, str):
|
||||
if value is not None:
|
||||
if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
|
||||
raise ValueError(f"{variable} in input form must be a string")
|
||||
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
|
||||
if '.' in value:
|
||||
value = float(value)
|
||||
else:
|
||||
value = int(value)
|
||||
|
||||
if variable_config.type == VariableEntity.Type.SELECT:
|
||||
options = variable_config.options if variable_config.options is not None else []
|
||||
if value not in options:
|
||||
raise ValueError(f"{variable} in input form must be one of the following: {options}")
|
||||
else:
|
||||
elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
|
||||
if variable_config.max_length is not None:
|
||||
max_length = variable_config.max_length
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
||||
|
||||
filtered_inputs[variable] = value.replace('\x00', '') if value else None
|
||||
if value and isinstance(value, str):
|
||||
filtered_inputs[variable] = value.replace('\x00', '')
|
||||
else:
|
||||
filtered_inputs[variable] = value if value is not None else None
|
||||
|
||||
return filtered_inputs
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -36,6 +36,14 @@ class WorkflowAppRunner:
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
@ -67,7 +75,8 @@ class WorkflowAppRunner:
|
||||
else UserFrom.END_USER,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.FILES: files
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.USER_ID: user_id
|
||||
},
|
||||
callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
@ -5,6 +5,8 @@ from typing import cast
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
@ -68,4 +70,24 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
:param stream_response: stream response
|
||||
:return:
|
||||
"""
|
||||
return cls.convert_stream_full_response(stream_response)
|
||||
for chunk in stream_response:
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield 'ping'
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
'event': sub_stream_response.event.value,
|
||||
'workflow_run_id': chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield json.dumps(response_chunk)
|
||||
|
||||
@ -28,11 +28,13 @@ from core.app.entities.task_entities import (
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStreamGenerateNodes,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
from core.workflow.entities.node_entities import NodeType, SystemVariable
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
@ -40,6 +42,7 @@ from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
@ -71,12 +74,19 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
"""
|
||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||
|
||||
if isinstance(self._user, EndUser):
|
||||
user_id = self._user.session_id
|
||||
else:
|
||||
user_id = self._user.id
|
||||
|
||||
self._workflow = workflow
|
||||
self._workflow_system_variables = {
|
||||
SystemVariable.FILES: application_generate_entity.files,
|
||||
SystemVariable.USER_ID: user_id
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._stream_generate_nodes = self._get_stream_generate_nodes()
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@ -161,6 +171,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
workflow_node_execution = self._handle_node_start(event)
|
||||
|
||||
# search stream_generate_routes if node id is answer start at node
|
||||
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
|
||||
self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
|
||||
|
||||
# generate stream outputs when node started
|
||||
yield from self._generate_stream_outputs_when_node_started()
|
||||
|
||||
yield self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -168,6 +186,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
)
|
||||
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
||||
workflow_node_execution = self._handle_node_finished(event)
|
||||
|
||||
yield self._workflow_node_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
@ -187,6 +206,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
if delta_text is None:
|
||||
continue
|
||||
|
||||
if not self._is_stream_out_support(
|
||||
event=event
|
||||
):
|
||||
continue
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._text_chunk_to_stream_response(delta_text)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
@ -248,3 +272,142 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
text=TextReplaceStreamResponse.Data(text=text)
|
||||
)
|
||||
|
||||
def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
|
||||
"""
|
||||
Get stream generate nodes.
|
||||
:return:
|
||||
"""
|
||||
# find all answer nodes
|
||||
graph = self._workflow.graph_dict
|
||||
end_node_configs = [
|
||||
node for node in graph['nodes']
|
||||
if node.get('data', {}).get('type') == NodeType.END.value
|
||||
]
|
||||
|
||||
# parse stream output node value selectors of end nodes
|
||||
stream_generate_routes = {}
|
||||
for node_config in end_node_configs:
|
||||
# get generate route for stream output
|
||||
end_node_id = node_config['id']
|
||||
generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
|
||||
start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
|
||||
if not start_node_ids:
|
||||
continue
|
||||
|
||||
for start_node_id in start_node_ids:
|
||||
stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
|
||||
end_node_id=end_node_id,
|
||||
stream_node_ids=generate_nodes
|
||||
)
|
||||
|
||||
return stream_generate_routes
|
||||
|
||||
def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
|
||||
-> list[str]:
|
||||
"""
|
||||
Get end start at node id.
|
||||
:param graph: graph
|
||||
:param target_node_id: target node ID
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
edges = graph.get('edges')
|
||||
|
||||
# fetch all ingoing edges from source node
|
||||
ingoing_edges = []
|
||||
for edge in edges:
|
||||
if edge.get('target') == target_node_id:
|
||||
ingoing_edges.append(edge)
|
||||
|
||||
if not ingoing_edges:
|
||||
return []
|
||||
|
||||
start_node_ids = []
|
||||
for ingoing_edge in ingoing_edges:
|
||||
source_node_id = ingoing_edge.get('source')
|
||||
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
node_type = source_node.get('data', {}).get('type')
|
||||
if node_type in [
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER.value
|
||||
]:
|
||||
start_node_id = target_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
elif node_type == NodeType.START.value:
|
||||
start_node_id = source_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
else:
|
||||
sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
|
||||
if sub_start_node_ids:
|
||||
start_node_ids.extend(sub_start_node_ids)
|
||||
|
||||
return start_node_ids
|
||||
|
||||
def _generate_stream_outputs_when_node_started(self) -> Generator:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:return:
|
||||
"""
|
||||
if self._task_state.current_stream_generate_state:
|
||||
stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
|
||||
|
||||
for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
|
||||
if node_id not in stream_node_ids:
|
||||
continue
|
||||
|
||||
node_execution_info = self._task_state.ran_node_execution_infos[node_id]
|
||||
|
||||
# get chunk node execution
|
||||
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
|
||||
|
||||
if not route_chunk_node_execution:
|
||||
continue
|
||||
|
||||
outputs = route_chunk_node_execution.outputs_dict
|
||||
|
||||
if not outputs:
|
||||
continue
|
||||
|
||||
# get value from outputs
|
||||
text = outputs.get('text')
|
||||
|
||||
if text:
|
||||
self._task_state.answer += text
|
||||
yield self._text_chunk_to_stream_response(text)
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
|
||||
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
:return:
|
||||
"""
|
||||
if not event.metadata:
|
||||
return False
|
||||
|
||||
if 'node_id' not in event.metadata:
|
||||
return False
|
||||
|
||||
node_id = event.metadata.get('node_id')
|
||||
node_type = event.metadata.get('node_type')
|
||||
stream_output_value_selector = event.metadata.get('value_selector')
|
||||
if not stream_output_value_selector:
|
||||
return False
|
||||
|
||||
if not self._task_state.current_stream_generate_state:
|
||||
return False
|
||||
|
||||
if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
|
||||
return False
|
||||
|
||||
if node_type != NodeType.LLM:
|
||||
# only LLM support chunk stream output
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@ -6,6 +6,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
@ -119,7 +120,15 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
pass
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text,
|
||||
metadata={
|
||||
"node_id": node_id,
|
||||
**metadata
|
||||
}
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
|
||||
@ -72,7 +72,7 @@ class AppGenerateEntity(BaseModel):
|
||||
# app config
|
||||
app_config: AppConfig
|
||||
|
||||
inputs: dict[str, str]
|
||||
inputs: dict[str, Any]
|
||||
files: list[FileVar] = []
|
||||
user_id: str
|
||||
|
||||
|
||||
@ -9,9 +9,17 @@ from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.entities import GenerateRouteChunk
|
||||
|
||||
|
||||
class StreamGenerateRoute(BaseModel):
|
||||
class WorkflowStreamGenerateNodes(BaseModel):
|
||||
"""
|
||||
StreamGenerateRoute entity
|
||||
WorkflowStreamGenerateNodes entity
|
||||
"""
|
||||
end_node_id: str
|
||||
stream_node_ids: list[str]
|
||||
|
||||
|
||||
class ChatflowStreamGenerateRoute(BaseModel):
|
||||
"""
|
||||
ChatflowStreamGenerateRoute entity
|
||||
"""
|
||||
answer_node_id: str
|
||||
generate_route: list[GenerateRouteChunk]
|
||||
@ -55,6 +63,8 @@ class WorkflowTaskState(TaskState):
|
||||
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
|
||||
latest_node_execution_info: Optional[NodeExecutionInfo] = None
|
||||
|
||||
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
|
||||
|
||||
|
||||
class AdvancedChatTaskState(WorkflowTaskState):
|
||||
"""
|
||||
@ -62,7 +72,7 @@ class AdvancedChatTaskState(WorkflowTaskState):
|
||||
"""
|
||||
usage: LLMUsage
|
||||
|
||||
current_stream_generate_state: Optional[StreamGenerateRoute] = None
|
||||
current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
|
||||
|
||||
|
||||
class StreamEvent(Enum):
|
||||
@ -236,6 +246,24 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self):
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
"workflow_run_id": self.workflow_run_id,
|
||||
"data": {
|
||||
"id": self.data.id,
|
||||
"node_id": self.data.node_id,
|
||||
"node_type": self.data.node_type,
|
||||
"title": self.data.title,
|
||||
"index": self.data.index,
|
||||
"predecessor_node_id": self.data.predecessor_node_id,
|
||||
"inputs": None,
|
||||
"created_at": self.data.created_at,
|
||||
"extras": {}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class NodeFinishStreamResponse(StreamResponse):
|
||||
"""
|
||||
@ -266,6 +294,31 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self):
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
"workflow_run_id": self.workflow_run_id,
|
||||
"data": {
|
||||
"id": self.data.id,
|
||||
"node_id": self.data.node_id,
|
||||
"node_type": self.data.node_type,
|
||||
"title": self.data.title,
|
||||
"index": self.data.index,
|
||||
"predecessor_node_id": self.data.predecessor_node_id,
|
||||
"inputs": None,
|
||||
"process_data": None,
|
||||
"outputs": None,
|
||||
"status": self.data.status,
|
||||
"error": None,
|
||||
"elapsed_time": self.data.elapsed_time,
|
||||
"execution_metadata": None,
|
||||
"created_at": self.data.created_at,
|
||||
"finished_at": self.data.finished_at,
|
||||
"files": []
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TextChunkStreamResponse(StreamResponse):
|
||||
"""
|
||||
|
||||
@ -118,7 +118,8 @@ class MessageCycleManage:
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||
|
||||
def _get_response_metadata(self) -> dict:
|
||||
"""
|
||||
|
||||
@ -1,13 +1,20 @@
|
||||
import logging
|
||||
import time
|
||||
from enum import Enum
|
||||
from threading import Lock
|
||||
from typing import Literal, Optional
|
||||
|
||||
from httpx import post
|
||||
from httpx import get, post
|
||||
from pydantic import BaseModel
|
||||
from yarl import URL
|
||||
|
||||
from config import get_env
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
from core.helper.code_executor.javascript_transformer import NodeJsTemplateTransformer
|
||||
from core.helper.code_executor.jina2_transformer import Jinja2TemplateTransformer
|
||||
from core.helper.code_executor.python_transformer import PythonTemplateTransformer
|
||||
from core.helper.code_executor.jinja2_transformer import Jinja2TemplateTransformer
|
||||
from core.helper.code_executor.python_transformer import PYTHON_STANDARD_PACKAGES, PythonTemplateTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Code Executor
|
||||
CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
|
||||
@ -27,10 +34,34 @@ class CodeExecutionResponse(BaseModel):
|
||||
message: str
|
||||
data: Data
|
||||
|
||||
class CodeLanguage(str, Enum):
|
||||
PYTHON3 = 'python3'
|
||||
JINJA2 = 'jinja2'
|
||||
JAVASCRIPT = 'javascript'
|
||||
|
||||
|
||||
class CodeExecutor:
|
||||
dependencies_cache = {}
|
||||
dependencies_cache_lock = Lock()
|
||||
|
||||
code_template_transformers = {
|
||||
CodeLanguage.PYTHON3: PythonTemplateTransformer,
|
||||
CodeLanguage.JINJA2: Jinja2TemplateTransformer,
|
||||
CodeLanguage.JAVASCRIPT: NodeJsTemplateTransformer,
|
||||
}
|
||||
|
||||
code_language_to_running_language = {
|
||||
CodeLanguage.JAVASCRIPT: 'nodejs',
|
||||
CodeLanguage.JINJA2: CodeLanguage.PYTHON3,
|
||||
CodeLanguage.PYTHON3: CodeLanguage.PYTHON3,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def execute_code(cls, language: Literal['python3', 'javascript', 'jinja2'], preload: str, code: str) -> str:
|
||||
def execute_code(cls,
|
||||
language: Literal['python3', 'javascript', 'jinja2'],
|
||||
preload: str,
|
||||
code: str,
|
||||
dependencies: Optional[list[CodeDependency]] = None) -> str:
|
||||
"""
|
||||
Execute code
|
||||
:param language: code language
|
||||
@ -44,13 +75,15 @@ class CodeExecutor:
|
||||
}
|
||||
|
||||
data = {
|
||||
'language': 'python3' if language == 'jinja2' else
|
||||
'nodejs' if language == 'javascript' else
|
||||
'python3' if language == 'python3' else None,
|
||||
'language': cls.code_language_to_running_language.get(language),
|
||||
'code': code,
|
||||
'preload': preload
|
||||
'preload': preload,
|
||||
'enable_network': True
|
||||
}
|
||||
|
||||
if dependencies:
|
||||
data['dependencies'] = [dependency.dict() for dependency in dependencies]
|
||||
|
||||
try:
|
||||
response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
|
||||
if response.status_code == 503:
|
||||
@ -78,7 +111,7 @@ class CodeExecutor:
|
||||
return response.data.stdout
|
||||
|
||||
@classmethod
|
||||
def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict) -> dict:
|
||||
def execute_workflow_code_template(cls, language: Literal['python3', 'javascript', 'jinja2'], code: str, inputs: dict, dependencies: Optional[list[CodeDependency]] = None) -> dict:
|
||||
"""
|
||||
Execute code
|
||||
:param language: code language
|
||||
@ -86,21 +119,67 @@ class CodeExecutor:
|
||||
:param inputs: inputs
|
||||
:return:
|
||||
"""
|
||||
template_transformer = None
|
||||
if language == 'python3':
|
||||
template_transformer = PythonTemplateTransformer
|
||||
elif language == 'jinja2':
|
||||
template_transformer = Jinja2TemplateTransformer
|
||||
elif language == 'javascript':
|
||||
template_transformer = NodeJsTemplateTransformer
|
||||
else:
|
||||
raise CodeExecutionException('Unsupported language')
|
||||
template_transformer = cls.code_template_transformers.get(language)
|
||||
if not template_transformer:
|
||||
raise CodeExecutionException(f'Unsupported language {language}')
|
||||
|
||||
runner, preload = template_transformer.transform_caller(code, inputs)
|
||||
runner, preload, dependencies = template_transformer.transform_caller(code, inputs, dependencies)
|
||||
|
||||
try:
|
||||
response = cls.execute_code(language, preload, runner)
|
||||
response = cls.execute_code(language, preload, runner, dependencies)
|
||||
except CodeExecutionException as e:
|
||||
raise e
|
||||
|
||||
return template_transformer.transform_response(response)
|
||||
return template_transformer.transform_response(response)
|
||||
|
||||
@classmethod
|
||||
def list_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]:
|
||||
with cls.dependencies_cache_lock:
|
||||
if language in cls.dependencies_cache:
|
||||
# check expiration
|
||||
dependencies = cls.dependencies_cache[language]
|
||||
if dependencies['expiration'] > time.time():
|
||||
return dependencies['data']
|
||||
# remove expired cache
|
||||
del cls.dependencies_cache[language]
|
||||
|
||||
dependencies = cls._get_dependencies(language)
|
||||
with cls.dependencies_cache_lock:
|
||||
cls.dependencies_cache[language] = {
|
||||
'data': dependencies,
|
||||
'expiration': time.time() + 60
|
||||
}
|
||||
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def _get_dependencies(cls, language: Literal['python3']) -> list[CodeDependency]:
|
||||
"""
|
||||
List dependencies
|
||||
"""
|
||||
url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'dependencies'
|
||||
|
||||
headers = {
|
||||
'X-Api-Key': CODE_EXECUTION_API_KEY
|
||||
}
|
||||
|
||||
running_language = cls.code_language_to_running_language.get(language)
|
||||
if isinstance(running_language, Enum):
|
||||
running_language = running_language.value
|
||||
|
||||
data = {
|
||||
'language': running_language,
|
||||
}
|
||||
|
||||
try:
|
||||
response = get(str(url), params=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
|
||||
if response.status_code != 200:
|
||||
raise Exception(f'Failed to list dependencies, got status code {response.status_code}, please check if the sandbox service is running')
|
||||
response = response.json()
|
||||
dependencies = response.get('data', {}).get('dependencies', [])
|
||||
return [
|
||||
CodeDependency(**dependency) for dependency in dependencies if dependency.get('name') not in PYTHON_STANDARD_PACKAGES
|
||||
]
|
||||
except Exception as e:
|
||||
logger.exception(f'Failed to list dependencies: {e}')
|
||||
return []
|
||||
6
api/core/helper/code_executor/entities.py
Normal file
6
api/core/helper/code_executor/entities.py
Normal file
@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeDependency(BaseModel):
|
||||
name: str
|
||||
version: str
|
||||
@ -1,6 +1,8 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
NODEJS_RUNNER = """// declare main function here
|
||||
@ -22,7 +24,8 @@ NODEJS_PRELOAD = """"""
|
||||
|
||||
class NodeJsTemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
|
||||
def transform_caller(cls, code: str, inputs: dict,
|
||||
dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
|
||||
"""
|
||||
Transform code to python runner
|
||||
:param code: code
|
||||
@ -37,7 +40,7 @@ class NodeJsTemplateTransformer(TemplateTransformer):
|
||||
runner = NODEJS_RUNNER.replace('{{code}}', code)
|
||||
runner = runner.replace('{{inputs}}', inputs_str)
|
||||
|
||||
return runner, NODEJS_PRELOAD
|
||||
return runner, NODEJS_PRELOAD, []
|
||||
|
||||
@classmethod
|
||||
def transform_response(cls, response: str) -> dict:
|
||||
|
||||
17
api/core/helper/code_executor/jinja2_formatter.py
Normal file
17
api/core/helper/code_executor/jinja2_formatter.py
Normal file
@ -0,0 +1,17 @@
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
|
||||
|
||||
class Jinja2Formatter:
|
||||
@classmethod
|
||||
def format(cls, template: str, inputs: str) -> str:
|
||||
"""
|
||||
Format template
|
||||
:param template: template
|
||||
:param inputs: inputs
|
||||
:return:
|
||||
"""
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language='jinja2', code=template, inputs=inputs
|
||||
)
|
||||
|
||||
return result['result']
|
||||
@ -1,10 +1,16 @@
|
||||
import json
|
||||
import re
|
||||
from base64 import b64encode
|
||||
from typing import Optional
|
||||
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
from core.helper.code_executor.python_transformer import PYTHON_STANDARD_PACKAGES
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
PYTHON_RUNNER = """
|
||||
import jinja2
|
||||
from json import loads
|
||||
from base64 import b64decode
|
||||
|
||||
template = jinja2.Template('''{{code}}''')
|
||||
|
||||
@ -12,7 +18,8 @@ def main(**inputs):
|
||||
return template.render(**inputs)
|
||||
|
||||
# execute main function, and return the result
|
||||
output = main(**{{inputs}})
|
||||
inputs = b64decode('{{inputs}}').decode('utf-8')
|
||||
output = main(**loads(inputs))
|
||||
|
||||
result = f'''<<RESULT>>{output}<<RESULT>>'''
|
||||
|
||||
@ -39,6 +46,7 @@ JINJA2_PRELOAD_TEMPLATE = """{% set fruits = ['Apple'] %}
|
||||
|
||||
JINJA2_PRELOAD = f"""
|
||||
import jinja2
|
||||
from base64 import b64decode
|
||||
|
||||
def _jinja2_preload_():
|
||||
# prepare jinja2 environment, load template and render before to avoid sandbox issue
|
||||
@ -50,9 +58,11 @@ if __name__ == '__main__':
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
|
||||
def transform_caller(cls, code: str, inputs: dict,
|
||||
dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
|
||||
"""
|
||||
Transform code to python runner
|
||||
:param code: code
|
||||
@ -60,11 +70,25 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
:return:
|
||||
"""
|
||||
|
||||
inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
|
||||
|
||||
# transform jinja2 template to python code
|
||||
runner = PYTHON_RUNNER.replace('{{code}}', code)
|
||||
runner = runner.replace('{{inputs}}', json.dumps(inputs, indent=4, ensure_ascii=False))
|
||||
runner = runner.replace('{{inputs}}', inputs_str)
|
||||
|
||||
return runner, JINJA2_PRELOAD
|
||||
if not dependencies:
|
||||
dependencies = []
|
||||
|
||||
# add native packages and jinja2
|
||||
for package in PYTHON_STANDARD_PACKAGES.union(['jinja2']):
|
||||
dependencies.append(CodeDependency(name=package, version=''))
|
||||
|
||||
# deduplicate
|
||||
dependencies = list({
|
||||
dep.name: dep for dep in dependencies if dep.name
|
||||
}.values())
|
||||
|
||||
return runner, JINJA2_PRELOAD, dependencies
|
||||
|
||||
@classmethod
|
||||
def transform_response(cls, response: str) -> dict:
|
||||
@ -1,17 +1,24 @@
|
||||
import json
|
||||
import re
|
||||
from base64 import b64encode
|
||||
from typing import Optional
|
||||
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
PYTHON_RUNNER = """# declare main function here
|
||||
{{code}}
|
||||
|
||||
from json import loads, dumps
|
||||
from base64 import b64decode
|
||||
|
||||
# execute main function, and return the result
|
||||
# inputs is a dict, and it
|
||||
output = main(**{{inputs}})
|
||||
inputs = b64decode('{{inputs}}').decode('utf-8')
|
||||
output = main(**json.loads(inputs))
|
||||
|
||||
# convert output to json and print
|
||||
output = json.dumps(output, indent=4)
|
||||
output = dumps(output, indent=4)
|
||||
|
||||
result = f'''<<RESULT>>
|
||||
{output}
|
||||
@ -20,32 +27,17 @@ result = f'''<<RESULT>>
|
||||
print(result)
|
||||
"""
|
||||
|
||||
PYTHON_PRELOAD = """
|
||||
# prepare general imports
|
||||
import json
|
||||
import datetime
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import binascii
|
||||
import collections
|
||||
import functools
|
||||
import operator
|
||||
import itertools
|
||||
"""
|
||||
PYTHON_PRELOAD = """"""
|
||||
|
||||
PYTHON_STANDARD_PACKAGES = set([
|
||||
'json', 'datetime', 'math', 'random', 're', 'string', 'sys', 'time', 'traceback', 'uuid', 'os', 'base64',
|
||||
'hashlib', 'hmac', 'binascii', 'collections', 'functools', 'operator', 'itertools', 'uuid',
|
||||
])
|
||||
|
||||
class PythonTemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
|
||||
def transform_caller(cls, code: str, inputs: dict,
|
||||
dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
|
||||
"""
|
||||
Transform code to python runner
|
||||
:param code: code
|
||||
@ -54,13 +46,24 @@ class PythonTemplateTransformer(TemplateTransformer):
|
||||
"""
|
||||
|
||||
# transform inputs to json string
|
||||
inputs_str = json.dumps(inputs, indent=4, ensure_ascii=False)
|
||||
inputs_str = b64encode(json.dumps(inputs, ensure_ascii=False).encode()).decode('utf-8')
|
||||
|
||||
# replace code and inputs
|
||||
runner = PYTHON_RUNNER.replace('{{code}}', code)
|
||||
runner = runner.replace('{{inputs}}', inputs_str)
|
||||
|
||||
return runner, PYTHON_PRELOAD
|
||||
# add standard packages
|
||||
if dependencies is None:
|
||||
dependencies = []
|
||||
|
||||
for package in PYTHON_STANDARD_PACKAGES:
|
||||
if package not in dependencies:
|
||||
dependencies.append(CodeDependency(name=package, version=''))
|
||||
|
||||
# deduplicate
|
||||
dependencies = list({dep.name: dep for dep in dependencies if dep.name}.values())
|
||||
|
||||
return runner, PYTHON_PRELOAD, dependencies
|
||||
|
||||
@classmethod
|
||||
def transform_response(cls, response: str) -> dict:
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from core.helper.code_executor.entities import CodeDependency
|
||||
|
||||
|
||||
class TemplateTransformer(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def transform_caller(cls, code: str, inputs: dict) -> tuple[str, str]:
|
||||
def transform_caller(cls, code: str, inputs: dict,
|
||||
dependencies: Optional[list[CodeDependency]] = None) -> tuple[str, str, list[CodeDependency]]:
|
||||
"""
|
||||
Transform code to python runner
|
||||
:param code: code
|
||||
|
||||
@ -11,12 +11,13 @@ class ToolParameterCacheType(Enum):
|
||||
|
||||
class ToolParameterCache:
|
||||
def __init__(self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
cache_type: ToolParameterCacheType
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
cache_type: ToolParameterCacheType,
|
||||
identity_id: str
|
||||
):
|
||||
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
|
||||
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""
|
||||
|
||||
@ -26,4 +26,6 @@
|
||||
- yi
|
||||
- openllm
|
||||
- localai
|
||||
- volcengine_maas
|
||||
- openai_api_compatible
|
||||
- deepseek
|
||||
|
||||
@ -482,6 +482,158 @@ LLM_BASE_MODELS = [
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4-turbo',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.VISION,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
ModelFeature.STREAM_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||
ParameterRule(
|
||||
name='seed',
|
||||
label=I18nObject(
|
||||
zh_Hans='种子',
|
||||
en_US='Seed'
|
||||
),
|
||||
type='int',
|
||||
help=I18nObject(
|
||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
||||
),
|
||||
required=False,
|
||||
precision=2,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='response_format',
|
||||
label=I18nObject(
|
||||
zh_Hans='回复格式',
|
||||
en_US='response_format'
|
||||
),
|
||||
type='string',
|
||||
help=I18nObject(
|
||||
zh_Hans='指定模型必须输出的格式',
|
||||
en_US='specifying the format that the model must output'
|
||||
),
|
||||
required=False,
|
||||
options=['text', 'json_object']
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.001,
|
||||
output=0.003,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4-turbo-2024-04-09',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.VISION,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
ModelFeature.STREAM_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||
ParameterRule(
|
||||
name='seed',
|
||||
label=I18nObject(
|
||||
zh_Hans='种子',
|
||||
en_US='Seed'
|
||||
),
|
||||
type='int',
|
||||
help=I18nObject(
|
||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
||||
),
|
||||
required=False,
|
||||
precision=2,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='response_format',
|
||||
label=I18nObject(
|
||||
zh_Hans='回复格式',
|
||||
en_US='response_format'
|
||||
),
|
||||
type='string',
|
||||
help=I18nObject(
|
||||
zh_Hans='指定模型必须输出的格式',
|
||||
en_US='specifying the format that the model must output'
|
||||
),
|
||||
required=False,
|
||||
options=['text', 'json_object']
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=0.001,
|
||||
output=0.003,
|
||||
unit=0.001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4-vision-preview',
|
||||
entity=AIModelEntity(
|
||||
|
||||
@ -99,6 +99,12 @@ model_credential_schema:
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-turbo
|
||||
value: gpt-4-turbo
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-turbo-2024-04-09
|
||||
value: gpt-4-turbo-2024-04-09
|
||||
|
||||
@ -70,6 +70,10 @@ provider_credential_schema:
|
||||
label:
|
||||
en_US: AWS GovCloud (US-West)
|
||||
zh_Hans: AWS GovCloud (US-West)
|
||||
- value: ap-southeast-2
|
||||
label:
|
||||
en_US: Asia Pacific (Sydney)
|
||||
zh_Hans: 亚太地区 (悉尼)
|
||||
- variable: model_for_validation
|
||||
required: false
|
||||
label:
|
||||
|
||||
@ -8,5 +8,10 @@
|
||||
- anthropic.claude-3-haiku-v1:0
|
||||
- cohere.command-light-text-v14
|
||||
- cohere.command-text-v14
|
||||
- meta.llama3-8b-instruct-v1:0
|
||||
- meta.llama3-70b-instruct-v1:0
|
||||
- meta.llama2-13b-chat-v1
|
||||
- meta.llama2-70b-chat-v1
|
||||
- mistral.mistral-large-2402-v1:0
|
||||
- mistral.mixtral-8x7b-instruct-v0:1
|
||||
- mistral.mistral-7b-instruct-v0:2
|
||||
|
||||
@ -370,29 +370,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
:return:md = genai.GenerativeModel(model)
|
||||
"""
|
||||
prefix = model.split('.')[0]
|
||||
|
||||
model_name = model.split('.')[1]
|
||||
if isinstance(messages, str):
|
||||
prompt = messages
|
||||
else:
|
||||
prompt = self._convert_messages_to_prompt(messages, prefix)
|
||||
prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
|
||||
|
||||
return self._get_num_tokens_by_gpt2(prompt)
|
||||
|
||||
def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Google model
|
||||
|
||||
:param messages: List of PromptMessage to combine.
|
||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
messages = messages.copy() # don't mutate the original list
|
||||
|
||||
text = "".join(
|
||||
self._convert_one_message_to_text(message, model_prefix)
|
||||
for message in messages
|
||||
)
|
||||
|
||||
return text.rstrip()
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
@ -432,7 +417,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
|
||||
def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str:
|
||||
"""
|
||||
Convert a single message to a string.
|
||||
|
||||
@ -446,9 +431,21 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
ai_prompt = "\n\nAssistant:"
|
||||
|
||||
elif model_prefix == "meta":
|
||||
human_prompt_prefix = "\n[INST]"
|
||||
# LLAMA3
|
||||
if model_name.startswith("llama3"):
|
||||
human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
|
||||
human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
ai_prompt = "\n\nAssistant:"
|
||||
else:
|
||||
# LLAMA2
|
||||
human_prompt_prefix = "\n[INST]"
|
||||
human_prompt_postfix = "[\\INST]\n"
|
||||
ai_prompt = ""
|
||||
|
||||
elif model_prefix == "mistral":
|
||||
human_prompt_prefix = "<s>[INST]"
|
||||
human_prompt_postfix = "[\\INST]\n"
|
||||
ai_prompt = ""
|
||||
ai_prompt = "\n\nAssistant:"
|
||||
|
||||
elif model_prefix == "amazon":
|
||||
human_prompt_prefix = "\n\nUser:"
|
||||
@ -473,11 +470,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return message_text
|
||||
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str:
|
||||
def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str:
|
||||
"""
|
||||
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
|
||||
|
||||
:param messages: List of PromptMessage to combine.
|
||||
:param model_name: specific model name.Optional,just to distinguish llama2 and llama3
|
||||
:return: Combined string with necessary human_prompt and ai_prompt tags.
|
||||
"""
|
||||
if not messages:
|
||||
@ -488,18 +486,20 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
messages.append(AssistantPromptMessage(content=""))
|
||||
|
||||
text = "".join(
|
||||
self._convert_one_message_to_text(message, model_prefix)
|
||||
self._convert_one_message_to_text(message, model_prefix, model_name)
|
||||
for message in messages
|
||||
)
|
||||
|
||||
# trim off the trailing ' ' that might come from the "Assistant: "
|
||||
return text.rstrip()
|
||||
|
||||
def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
|
||||
def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
|
||||
"""
|
||||
Create payload for bedrock api call depending on model provider
|
||||
"""
|
||||
payload = dict()
|
||||
model_prefix = model.split('.')[0]
|
||||
model_name = model.split('.')[1]
|
||||
|
||||
if model_prefix == "amazon":
|
||||
payload["textGenerationConfig"] = { **model_parameters }
|
||||
@ -519,6 +519,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
|
||||
if model_parameters.get("countPenalty"):
|
||||
payload["countPenalty"] = {model_parameters.get("countPenalty")}
|
||||
|
||||
elif model_prefix == "mistral":
|
||||
payload["temperature"] = model_parameters.get("temperature")
|
||||
payload["top_p"] = model_parameters.get("top_p")
|
||||
payload["max_tokens"] = model_parameters.get("max_tokens")
|
||||
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
|
||||
payload["stop"] = stop[:10] if stop else []
|
||||
|
||||
elif model_prefix == "anthropic":
|
||||
payload = { **model_parameters }
|
||||
@ -532,7 +539,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
elif model_prefix == "meta":
|
||||
payload = { **model_parameters }
|
||||
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
|
||||
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Got unknown model prefix {model_prefix}")
|
||||
@ -567,7 +574,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
)
|
||||
|
||||
model_prefix = model.split('.')[0]
|
||||
payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
|
||||
payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)
|
||||
|
||||
# need workaround for ai21 models which doesn't support streaming
|
||||
if stream and model_prefix != "ai21":
|
||||
@ -648,6 +655,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
output = response_body.get("generation").strip('\n')
|
||||
prompt_tokens = response_body.get("prompt_token_count")
|
||||
completion_tokens = response_body.get("generation_token_count")
|
||||
|
||||
elif model_prefix == "mistral":
|
||||
output = response_body.get("outputs")[0].get("text")
|
||||
prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count')
|
||||
completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count')
|
||||
|
||||
else:
|
||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
||||
@ -731,6 +743,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
||||
content_delta = payload.get("text")
|
||||
finish_reason = payload.get("finish_reason")
|
||||
|
||||
elif model_prefix == "mistral":
|
||||
content_delta = payload.get('outputs')[0].get("text")
|
||||
finish_reason = payload.get('outputs')[0].get("stop_reason")
|
||||
|
||||
elif model_prefix == "meta":
|
||||
content_delta = payload.get("generation").strip('\n')
|
||||
finish_reason = payload.get("stop_reason")
|
||||
|
||||
@ -0,0 +1,23 @@
|
||||
model: meta.llama3-70b-instruct-v1:0
|
||||
label:
|
||||
en_US: Llama 3 Instruct 70B
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_gen_len
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 2048
|
||||
pricing:
|
||||
input: '0.00265'
|
||||
output: '0.0035'
|
||||
unit: '0.00001'
|
||||
currency: USD
|
||||
@ -0,0 +1,23 @@
|
||||
model: meta.llama3-8b-instruct-v1:0
|
||||
label:
|
||||
en_US: Llama 3 Instruct 8B
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_gen_len
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 2048
|
||||
pricing:
|
||||
input: '0.0004'
|
||||
output: '0.0006'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
@ -0,0 +1,39 @@
|
||||
model: mistral.mistral-7b-instruct-v0:2
|
||||
label:
|
||||
en_US: Mistral 7B Instruct
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
required: false
|
||||
default: 0.9
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 50
|
||||
max: 200
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00015'
|
||||
output: '0.0002'
|
||||
unit: '0.00001'
|
||||
currency: USD
|
||||
@ -0,0 +1,27 @@
|
||||
model: mistral.mistral-large-2402-v1:0
|
||||
label:
|
||||
en_US: Mistral Large
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
default: 0.7
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
required: false
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
pricing:
|
||||
input: '0.008'
|
||||
output: '0.024'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
@ -0,0 +1,39 @@
|
||||
model: mistral.mixtral-8x7b-instruct-v0:1
|
||||
label:
|
||||
en_US: Mixtral 8X7B Instruct
|
||||
model_type: llm
|
||||
model_properties:
|
||||
mode: completion
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
required: false
|
||||
default: 0.9
|
||||
- name: top_k
|
||||
use_template: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 50
|
||||
max: 200
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00045'
|
||||
output: '0.0007'
|
||||
unit: '0.00001'
|
||||
currency: USD
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from os.path import join
|
||||
from typing import Optional, cast
|
||||
|
||||
from httpx import Timeout
|
||||
@ -19,6 +18,7 @@ from openai import (
|
||||
)
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
@ -265,7 +265,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
|
||||
client_kwargs = {
|
||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||
"api_key": "1",
|
||||
"base_url": join(credentials['api_base'], 'v1')
|
||||
"base_url": str(URL(credentials['api_base']) / 'v1')
|
||||
}
|
||||
|
||||
return client_kwargs
|
||||
|
||||
@ -32,6 +32,15 @@ provider_credential_schema:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
show_on: [ ]
|
||||
- variable: base_url
|
||||
label:
|
||||
zh_Hans: API Base
|
||||
en_US: API Base
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1
|
||||
en_US: Enter your API Base, e.g. https://api.cohere.ai/v1
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
@ -70,3 +79,12 @@ model_credential_schema:
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: base_url
|
||||
label:
|
||||
zh_Hans: API Base
|
||||
en_US: API Base
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Base,如 https://api.cohere.ai/v1
|
||||
en_US: Enter your API Base, e.g. https://api.cohere.ai/v1
|
||||
|
||||
@ -173,7 +173,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
|
||||
if stop:
|
||||
model_parameters['end_sequences'] = stop
|
||||
@ -233,7 +233,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return response
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict,
|
||||
response: Iterator[GenerateStreamedResponse],
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
@ -317,7 +318,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
|
||||
if stop:
|
||||
model_parameters['stop_sequences'] = stop
|
||||
@ -636,7 +637,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
:return: number of tokens
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
|
||||
response = client.tokenize(
|
||||
text=text,
|
||||
|
||||
@ -44,7 +44,7 @@ class CohereRerankModel(RerankModel):
|
||||
)
|
||||
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
response = client.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
|
||||
@ -141,7 +141,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
return []
|
||||
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
|
||||
response = client.tokenize(
|
||||
text=text,
|
||||
@ -180,7 +180,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return: embeddings and used tokens
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
client = cohere.Client(credentials.get('api_key'), base_url=credentials.get('base_url'))
|
||||
|
||||
# call embedding model
|
||||
response = client.embed(
|
||||
|
||||
@ -0,0 +1,22 @@
|
||||
<svg width="195.000000" height="41.359375" viewBox="0 0 195 41.3594" fill="none" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<desc>
|
||||
Created with Pixso.
|
||||
</desc>
|
||||
<defs>
|
||||
<clipPath id="clip30_2029">
|
||||
<rect id="_图层_1" width="134.577469" height="25.511124" transform="translate(60.422485 10.022217)" fill="white"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
<g clip-path="url(#clip30_2029)">
|
||||
<path id="path" d="M119.508 30.113L117.562 30.113L117.562 27.0967L119.508 27.0967C120.713 27.0967 121.931 26.7961 122.715 25.9614C123.5 25.1265 123.796 23.8464 123.796 22.5664C123.796 21.2864 123.512 20.0063 122.715 19.1716C121.919 18.3369 120.713 18.0364 119.508 18.0364C118.302 18.0364 117.085 18.3369 116.3 19.1716C115.515 20.0063 115.219 21.2864 115.219 22.5664L115.219 34.9551L111.806 34.9551L111.806 15.031L115.219 15.031L115.219 16.2998L115.845 16.2998C115.913 16.2219 115.981 16.1553 116.049 16.0884C116.903 15.3093 118.211 15.031 119.496 15.031C121.51 15.031 123.523 15.532 124.843 16.9233C126.162 18.3145 126.629 20.4517 126.629 22.5776C126.629 24.7036 126.151 26.8296 124.843 28.2319C123.535 29.6345 121.51 30.113 119.508 30.113Z" fill-rule="nonzero" fill="#4D6BFE"/>
|
||||
<path id="path" d="M67.5664 15.5654L69.5117 15.5654L69.5117 18.5818L67.5664 18.5818C66.3606 18.5818 65.1434 18.8823 64.3585 19.717C63.5736 20.552 63.2778 21.832 63.2778 23.1121C63.2778 24.3921 63.5623 25.6721 64.3585 26.5068C65.1548 27.3418 66.3606 27.6423 67.5664 27.6423C68.7722 27.6423 69.9895 27.3418 70.7744 26.5068C71.5593 25.6721 71.8551 24.3921 71.8551 23.1121L71.8551 10.7124L75.2677 10.7124L75.2677 30.6475L71.8551 30.6475L71.8551 29.3787L71.2294 29.3787C71.1611 29.4565 71.0929 29.5234 71.0247 29.5901C70.1715 30.3691 68.8633 30.6475 67.5779 30.6475C65.5643 30.6475 63.5509 30.1467 62.2313 28.7554C60.9117 27.364 60.4453 25.2268 60.4453 23.1008C60.4453 20.9749 60.9231 18.8489 62.2313 17.4465C63.5509 16.0552 65.5643 15.5654 67.5664 15.5654Z" fill-rule="nonzero" fill="#4D6BFE"/>
|
||||
<path id="path" d="M92.3881 22.845L92.3881 24.0581L83.299 24.0581L83.299 21.6428L89.328 21.6428C89.1914 20.7634 88.8729 19.9397 88.3042 19.3386C87.4851 18.4705 86.2224 18.1589 84.9711 18.1589C83.7198 18.1589 82.4572 18.4705 81.6381 19.3386C80.819 20.2068 80.5232 21.5315 80.5232 22.845C80.5232 24.1582 80.819 25.4939 81.6381 26.3511C82.4572 27.208 83.7198 27.531 84.9711 27.531C86.2224 27.531 87.4851 27.2192 88.3042 26.3511C88.418 26.2285 88.5203 26.095 88.6227 25.9614L91.9899 25.9614C91.6941 27.0078 91.2277 27.9539 90.5225 28.6885C89.1573 30.1243 87.0529 30.6475 84.9711 30.6475C82.8894 30.6475 80.7849 30.1355 79.4198 28.6885C78.0547 27.2415 77.5542 25.0376 77.5542 22.845C77.5542 20.6521 78.0433 18.437 79.4198 17.0012C80.7963 15.5654 82.8894 15.0422 84.9711 15.0422C87.0529 15.0422 89.1573 15.5542 90.5225 17.0012C91.8988 18.4482 92.3881 20.6521 92.3881 22.845Z" fill-rule="nonzero" fill="#4D6BFE"/>
|
||||
<path id="path" d="M109.52 22.845L109.52 24.0581L100.431 24.0581L100.431 21.6428L106.46 21.6428C106.323 20.7634 106.005 19.9397 105.436 19.3386C104.617 18.4705 103.354 18.1589 102.103 18.1589C100.852 18.1589 99.5889 18.4705 98.7698 19.3386C97.9507 20.2068 97.6549 21.5315 97.6549 22.845C97.6549 24.1582 97.9507 25.4939 98.7698 26.3511C99.5889 27.208 100.852 27.531 102.103 27.531C103.354 27.531 104.617 27.2192 105.436 26.3511C105.55 26.2285 105.652 26.095 105.754 25.9614L109.122 25.9614C108.826 27.0078 108.359 27.9539 107.654 28.6885C106.289 30.1243 104.185 30.6475 102.103 30.6475C100.021 30.6475 97.9166 30.1355 96.5515 28.6885C95.1864 27.2415 94.6859 25.0376 94.6859 22.845C94.6859 20.6521 95.175 18.437 96.5515 17.0012C97.928 15.5654 100.021 15.0422 102.103 15.0422C104.185 15.0422 106.289 15.5542 107.654 17.0012C109.031 18.4482 109.52 20.6521 109.52 22.845Z" fill-rule="nonzero" fill="#4D6BFE"/>
|
||||
<path id="path" d="M136.355 30.6475C138.437 30.6475 140.541 30.3469 141.906 29.49C143.271 28.6328 143.772 27.3306 143.772 26.0393C143.772 24.7483 143.282 23.4348 141.906 22.5889C140.541 21.7429 138.437 21.4312 136.355 21.4312C135.467 21.4312 134.648 21.3088 134.068 20.9861C133.488 20.6521 133.272 20.1511 133.272 19.6504C133.272 19.1494 133.477 18.6375 134.068 18.3147C134.648 17.9807 135.547 17.8694 136.434 17.8694C137.322 17.8694 138.22 17.9919 138.801 18.3147C139.381 18.6487 139.597 19.1494 139.597 19.6504L143.066 19.6504C143.066 18.3591 142.623 17.0457 141.383 16.2C140.143 15.354 138.243 15.0422 136.355 15.0422C134.466 15.0422 132.567 15.3428 131.327 16.2C130.087 17.0569 129.643 18.3591 129.643 19.6504C129.643 20.9414 130.087 22.2549 131.327 23.1008C132.567 23.9468 134.466 24.2585 136.355 24.2585C137.333 24.2585 138.414 24.3809 139.062 24.7036C139.711 25.0266 139.938 25.5386 139.938 26.0393C139.938 26.5403 139.711 27.0522 139.062 27.375C138.414 27.6978 137.424 27.8203 136.446 27.8203C135.467 27.8203 134.466 27.6978 133.829 27.375C133.192 27.0522 132.953 26.5403 132.953 26.0393L128.949 26.0393C128.949 27.3306 129.438 28.644 130.815 29.49C132.191 30.3359 134.273 30.6475 136.355 30.6475Z" fill-rule="nonzero" fill="#4D6BFE"/>
|
||||
<path id="path" d="M160.903 22.845L160.903 24.0581L151.814 24.0581L151.814 21.6428L157.843 21.6428C157.707 20.7634 157.388 19.9397 156.82 19.3386C156 18.4705 154.738 18.1589 153.486 18.1589C152.235 18.1589 150.972 18.4705 150.153 19.3386C149.334 20.2068 149.039 21.5315 149.039 22.845C149.039 24.1582 149.334 25.4939 150.153 26.3511C150.972 27.208 152.235 27.531 153.486 27.531C154.738 27.531 156 27.2192 156.82 26.3511C156.933 26.2285 157.036 26.095 157.138 25.9614L160.505 25.9614C160.209 27.0078 159.743 27.9539 159.038 28.6885C157.673 30.1243 155.568 30.6475 153.486 30.6475C151.405 30.6475 149.3 30.1355 147.935 28.6885C146.57 27.2415 146.07 25.0376 146.07 22.845C146.07 20.6521 146.559 18.437 147.935 17.0012C149.312 15.5654 151.405 15.0422 153.486 15.0422C155.568 15.0422 157.673 15.5542 159.038 17.0012C160.414 18.4482 160.903 20.6521 160.903 22.845Z" fill-rule="nonzero" fill="#4D6BFE"/>
|
||||
<path id="path" d="M178.035 22.845L178.035 24.0581L168.946 24.0581L168.946 21.6428L174.975 21.6428C174.839 20.7634 174.52 19.9397 173.951 19.3386C173.132 18.4705 171.87 18.1589 170.618 18.1589C169.367 18.1589 168.104 18.4705 167.285 19.3386C166.466 20.2068 166.17 21.5315 166.17 22.845C166.17 24.1582 166.466 25.4939 167.285 26.3511C168.104 27.208 169.367 27.531 170.618 27.531C171.87 27.531 173.132 27.2192 173.951 26.3511C174.065 26.2285 174.167 26.095 174.27 25.9614L177.637 25.9614C177.341 27.0078 176.875 27.9539 176.17 28.6885C174.804 30.1243 172.7 30.6475 170.618 30.6475C168.536 30.6475 166.432 30.1355 165.067 28.6885C163.702 27.2415 163.201 25.0376 163.201 22.845C163.201 20.6521 163.69 18.437 165.067 17.0012C166.443 15.5654 168.536 15.0422 170.618 15.0422C172.7 15.0422 174.804 15.5542 176.17 17.0012C177.546 18.4482 178.035 20.6521 178.035 22.845Z" fill-rule="nonzero" fill="#4D6BFE"/>
|
||||
<rect id="rect" x="180.321533" y="10.022217" width="3.412687" height="20.625223" fill="#4D6BFE"/>
|
||||
<path id="polygon" d="M189.559 22.3772L195.155 30.6475L190.935 30.6475L185.338 22.3772L190.935 15.7322L195.155 15.7322L189.559 22.3772Z" fill-rule="nonzero" fill="#4D6BFE"/>
|
||||
</g>
|
||||
<path id="path" d="M55.6128 3.47119C55.0175 3.17944 54.7611 3.73535 54.413 4.01782C54.2939 4.10889 54.1932 4.22729 54.0924 4.33667C53.2223 5.26587 52.2057 5.87646 50.8776 5.80347C48.9359 5.69409 47.2781 6.30469 45.8126 7.78979C45.5012 5.9585 44.4663 4.86499 42.8909 4.16357C42.0667 3.79907 41.2332 3.43457 40.6561 2.64185C40.2532 2.07715 40.1432 1.44849 39.9418 0.828857C39.8135 0.455322 39.6853 0.0725098 39.2548 0.00878906C38.7877 -0.0639648 38.6045 0.327637 38.4213 0.655762C37.6886 1.99512 37.4047 3.47119 37.4321 4.96533C37.4962 8.32739 38.9159 11.0059 41.7369 12.9102C42.0575 13.1289 42.1399 13.3474 42.0392 13.6665C41.8468 14.3225 41.6178 14.9602 41.4164 15.6162C41.2881 16.0354 41.0957 16.1265 40.647 15.9441C39.0991 15.2974 37.7618 14.3406 36.5803 13.1836C34.5745 11.2429 32.761 9.10181 30.4988 7.42529C29.9675 7.03345 29.4363 6.66919 28.8867 6.32275C26.5786 4.08154 29.189 2.24097 29.7935 2.02246C30.4254 1.79468 30.0133 1.01099 27.9708 1.02026C25.9283 1.0293 24.0599 1.71265 21.6786 2.62378C21.3306 2.7605 20.9641 2.8606 20.5886 2.94263C18.4271 2.53271 16.1831 2.44141 13.8384 2.70581C9.42371 3.19775 5.89758 5.28418 3.30554 8.84668C0.191406 13.1289 -0.54126 17.9941 0.356323 23.0691C1.29968 28.4172 4.02905 32.8452 8.22388 36.3076C12.5745 39.8972 17.5845 41.6558 23.2997 41.3186C26.771 41.1182 30.6361 40.6536 34.9958 36.9636C36.0948 37.5103 37.2489 37.7288 39.1632 37.8928C40.6378 38.0295 42.0575 37.8201 43.1565 37.5923C44.8784 37.2278 44.7594 35.6333 44.1366 35.3418C39.09 32.9912 40.1981 33.9478 39.1907 33.1733C41.7552 30.1394 45.6204 26.9868 47.1316 16.7732C47.2506 15.9624 47.1499 15.4521 47.1316 14.7961C47.1224 14.3953 47.214 14.2405 47.672 14.1948C48.9359 14.0491 50.1632 13.7029 51.2898 13.0833C54.5596 11.2976 55.8784 8.36377 56.1898 4.84692C56.2357 4.30933 56.1807 3.75342 55.6128 3.47119ZM27.119 35.123C22.2281 31.2783 19.856 30.0117 18.8759 30.0664C17.96 30.1211 18.1249 31.1689 18.3263 31.8523C18.537 32.5264 18.8118 32.9912 19.1964 33.5833C19.462 33.9751 19.6453 34.5581 18.9309 34.9956C17.3555 35.9705 14.6169 34.6675 14.4886 34.6038C11.3014 32.7268 8.63611 30.2485 6.75842 26.8594C4.94495 23.5974 3.89172 20.0989 3.71765 16.3633C3.67188 15.4614 3.9375 15.1423 4.83508 14.9785C6.0166 14.7598 7.23474 14.7141 8.41626 14.8872C13.408 15.6162 17.6577 17.8484 21.2206 21.3835C23.2539 23.397 24.7926 25.8025 26.3772 28.1531C28.0624 30.6494 29.8759 33.0276 32.184 34.9773C32.9991 35.6606 33.6494 36.1799 34.2722 36.5627C32.3947 36.7722 29.2622 36.8179 27.119 35.123ZM29.4637 20.0442C29.4637 19.6433 29.7843 19.3245 30.1874 19.3245C30.2789 19.3245 30.3613 19.3425 30.4346 19.3699C30.5354 19.4065 30.627 19.4612 30.7002 19.543C30.8285 19.6707 30.9017 19.8528 30.9017 20.0442C30.9017 20.4451 30.5812 20.7639 30.1782 20.7639C29.7751 20.7639 29.4637 20.4451 29.4637 20.0442ZM36.7452 23.7798C36.2781 23.9712 35.811 24.135 35.3622 24.1533C34.6661 24.1897 33.9059 23.9072 33.4938 23.561C32.8527 23.0234 32.3947 22.7229 32.2023 21.7844C32.1199 21.3835 32.1656 20.7639 32.239 20.4087C32.4038 19.6433 32.2206 19.1514 31.6803 18.7048C31.2406 18.3403 30.6819 18.2402 30.0682 18.2402C29.8392 18.2402 29.6287 18.1399 29.4729 18.0579C29.2164 17.9304 29.0059 17.6116 29.2073 17.2197C29.2714 17.0923 29.5829 16.7825 29.6561 16.7278C30.4896 16.2539 31.4513 16.4089 32.3397 16.7642C33.1641 17.1013 33.7869 17.7209 34.6844 18.5955C35.6003 19.6523 35.7651 19.9441 36.2872 20.7366C36.6995 21.3562 37.075 21.9939 37.3314 22.7229C37.4871 23.1785 37.2856 23.552 36.7452 23.7798Z" fill-rule="nonzero" fill="#4D6BFE"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 10 KiB |
@ -0,0 +1,3 @@
|
||||
<svg width="60" height="50" viewBox="0 0 60 50" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" xmlns:serif="http://www.serif.com/" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2;">
|
||||
<path id="path" d="M55.613,3.471C55.018,3.179 54.761,3.735 54.413,4.018C54.294,4.109 54.193,4.227 54.092,4.337C53.222,5.266 52.206,5.876 50.878,5.803C48.936,5.694 47.278,6.305 45.813,7.79C45.501,5.959 44.466,4.865 42.891,4.164C42.067,3.799 41.233,3.435 40.656,2.642C40.253,2.077 40.143,1.448 39.942,0.829C39.814,0.455 39.685,0.073 39.255,0.009C38.788,-0.064 38.605,0.328 38.421,0.656C37.689,1.995 37.405,3.471 37.432,4.965C37.496,8.327 38.916,11.006 41.737,12.91C42.058,13.129 42.14,13.347 42.039,13.666C41.847,14.323 41.618,14.96 41.416,15.616C41.288,16.035 41.096,16.127 40.647,15.944C39.099,15.297 37.762,14.341 36.58,13.184C34.575,11.243 32.761,9.102 30.499,7.425C29.968,7.033 29.436,6.669 28.887,6.323C26.579,4.082 29.189,2.241 29.794,2.022C30.425,1.795 30.013,1.011 27.971,1.02C25.928,1.029 24.06,1.713 21.679,2.624C21.331,2.761 20.964,2.861 20.589,2.943C18.427,2.533 16.183,2.441 13.838,2.706C9.424,3.198 5.898,5.284 3.306,8.847C0.191,13.129 -0.541,17.994 0.356,23.069C1.3,28.417 4.029,32.845 8.224,36.308C12.575,39.897 17.584,41.656 23.3,41.319C26.771,41.118 30.636,40.654 34.996,36.964C36.095,37.51 37.249,37.729 39.163,37.893C40.638,38.03 42.058,37.82 43.157,37.592C44.878,37.228 44.759,35.633 44.137,35.342C39.09,32.991 40.198,33.948 39.191,33.173C41.755,30.139 45.62,26.987 47.132,16.773C47.251,15.962 47.15,15.452 47.132,14.796C47.122,14.395 47.214,14.241 47.672,14.195C48.936,14.049 50.163,13.703 51.29,13.083C54.56,11.298 55.878,8.364 56.19,4.847C56.236,4.309 56.181,3.753 55.613,3.471ZM27.119,35.123C22.228,31.278 19.856,30.012 18.876,30.066C17.96,30.121 18.125,31.169 18.326,31.852C18.537,32.526 18.812,32.991 19.196,33.583C19.462,33.975 19.645,34.558 18.931,34.996C17.356,35.971 14.617,34.667 14.489,34.604C11.301,32.727 8.636,30.249 6.758,26.859C4.945,23.597 3.892,20.099 3.718,16.363C3.672,15.461 3.938,15.142 4.835,14.979C6.017,14.76 7.235,14.714 8.416,14.887C13.408,15.616 17.658,17.848 21.221,21.384C23.254,23.397 24.793,25.802 26.377,28.153C28.062,30.649 29.876,33.028 32.184,34.977C32.999,35.661 33.649,36.18 34.272,36.563C32.395,36.772 29.262,36.818 27.119,35.123ZM29.464,20.044C29.464,19.643 29.784,19.325 30.187,19.325C30.279,19.325 30.361,19.343 30.435,19.37C30.535,19.407 30.627,19.461 30.7,19.543C30.828,19.671 30.902,19.853 30.902,20.044C30.902,20.445 30.581,20.764 30.178,20.764C29.775,20.764 29.464,20.445 29.464,20.044ZM36.745,23.78C36.278,23.971 35.811,24.135 35.362,24.153C34.666,24.19 33.906,23.907 33.494,23.561C32.853,23.023 32.395,22.723 32.202,21.784C32.12,21.384 32.166,20.764 32.239,20.409C32.404,19.643 32.221,19.151 31.68,18.705C31.241,18.34 30.682,18.24 30.068,18.24C29.839,18.24 29.629,18.14 29.473,18.058C29.216,17.93 29.006,17.612 29.207,17.22C29.271,17.092 29.583,16.783 29.656,16.728C30.49,16.254 31.451,16.409 32.34,16.764C33.164,17.101 33.787,17.721 34.684,18.596C35.6,19.652 35.765,19.944 36.287,20.737C36.7,21.356 37.075,21.994 37.331,22.723C37.487,23.179 37.286,23.552 36.745,23.78Z" style="fill:rgb(77,107,254);fill-rule:nonzero;"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.2 KiB |
33
api/core/model_runtime/model_providers/deepseek/deepseek.py
Normal file
33
api/core/model_runtime/model_providers/deepseek/deepseek.py
Normal file
@ -0,0 +1,33 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class DeepSeekProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `deepseek-chat` model for validate,
|
||||
# no matter what model you pass in, text completion model or chat model
|
||||
model_instance.validate_credentials(
|
||||
model='deepseek-chat',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
@ -0,0 +1,41 @@
|
||||
provider: deepseek
|
||||
label:
|
||||
en_US: deepseek
|
||||
zh_Hans: 深度求索
|
||||
description:
|
||||
en_US: Models provided by deepseek, such as deepseek-chat、deepseek-coder.
|
||||
zh_Hans: 深度求索提供的模型,例如 deepseek-chat、deepseek-coder 。
|
||||
icon_small:
|
||||
en_US: icon_s_en.svg
|
||||
icon_large:
|
||||
en_US: icon_l_en.svg
|
||||
background: "#c0cdff"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API Key from deepseek
|
||||
zh_Hans: 从深度求索获取 API Key
|
||||
url:
|
||||
en_US: https://platform.deepseek.com/api_keys
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
zh_Hans: 自定义 API endpoint 地址
|
||||
en_US: Custom API endpoint URL
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: Base URL, e.g. https://api.deepseek.com/v1 or https://api.deepseek.com
|
||||
en_US: Base URL, e.g. https://api.deepseek.com/v1 or https://api.deepseek.com
|
||||
@ -0,0 +1,2 @@
|
||||
- deepseek-chat
|
||||
- deepseek-coder
|
||||
@ -0,0 +1,64 @@
|
||||
model: deepseek-chat
|
||||
label:
|
||||
zh_Hans: deepseek-chat
|
||||
en_US: deepseek-chat
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
|
||||
en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 32000
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.01
|
||||
max: 1.00
|
||||
help:
|
||||
zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
|
||||
en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
|
||||
- name: logprobs
|
||||
help:
|
||||
zh_Hans: 是否返回所输出 token 的对数概率。如果为 true,则在 message 的 content 中返回每个输出 token 的对数概率。
|
||||
en_US: Whether to return the log probability of the output token. If true, returns the log probability of each output token in the content of message .
|
||||
type: boolean
|
||||
- name: top_logprobs
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
max: 20
|
||||
help:
|
||||
zh_Hans: 一个介于 0 到 20 之间的整数 N,指定每个输出位置返回输出概率 top N 的 token,且返回这些 token 的对数概率。指定此参数时,logprobs 必须为 true。
|
||||
en_US: An integer N between 0 and 20, specifying that each output position returns the top N tokens with output probability, and returns the logarithmic probability of these tokens. When specifying this parameter, logprobs must be true.
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
default: 0
|
||||
min: -2.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。
|
||||
en_US: A number between -2.0 and 2.0. If the value is positive, new tokens are penalized based on their frequency of occurrence in existing text, reducing the likelihood that the model will repeat the same content.
|
||||
pricing:
|
||||
input: '1'
|
||||
output: '2'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
@ -0,0 +1,26 @@
|
||||
model: deepseek-coder
|
||||
label:
|
||||
zh_Hans: deepseek-coder
|
||||
en_US: deepseek-coder
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 16000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0
|
||||
max: 1
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
min: 1
|
||||
max: 32000
|
||||
default: 1024
|
||||
113
api/core/model_runtime/model_providers/deepseek/llm/llm.py
Normal file
113
api/core/model_runtime/model_providers/deepseek/llm/llm.py
Normal file
@ -0,0 +1,113 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import tiktoken
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
|
||||
|
||||
|
||||
class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
|
||||
# refactored from openai model runtime, use cl100k_base for calculate token number
|
||||
def _num_tokens_from_string(self, model: str, text: str,
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Calculate num tokens for text completion model with tiktoken package.
|
||||
|
||||
:param model: model name
|
||||
:param text: prompt text
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
"""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
num_tokens = len(encoding.encode(text))
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
# refactored from openai model runtime, use cl100k_base for calculate token number
|
||||
def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
|
||||
num_tokens = 0
|
||||
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
# TODO: The current token calculation method for the image type is not implemented,
|
||||
# which need to download the image and then get the resolution for calculation,
|
||||
# and will increase the request delay
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
|
||||
value = text
|
||||
|
||||
if key == "tool_calls":
|
||||
for tool_call in value:
|
||||
for t_key, t_value in tool_call.items():
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
if t_key == "function":
|
||||
for f_key, f_value in t_value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(t_key))
|
||||
num_tokens += len(encoding.encode(t_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(str(value)))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if tools:
|
||||
num_tokens += self._num_tokens_for_tools(encoding, tools)
|
||||
|
||||
return num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
credentials['openai_api_key']=credentials['api_key']
|
||||
if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "":
|
||||
credentials['openai_api_base']='https://api.deepseek.com'
|
||||
else:
|
||||
parsed_url = urlparse(credentials['endpoint_url'])
|
||||
credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
|
||||
@ -19,7 +19,7 @@ class GroqProvider(ModelProvider):
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(
|
||||
model='llama2-70b-4096',
|
||||
model='llama3-8b-8192',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
||||
@ -0,0 +1,25 @@
|
||||
model: llama3-70b-8192
|
||||
label:
|
||||
zh_Hans: Llama-3-70B-8192
|
||||
en_US: Llama-3-70B-8192
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.05'
|
||||
output: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,25 @@
|
||||
model: llama3-8b-8192
|
||||
label:
|
||||
zh_Hans: Llama-3-8B-8192
|
||||
en_US: Llama-3-8B-8192
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.59'
|
||||
output: '0.79'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -19,6 +19,7 @@ supported_model_types:
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
@ -29,3 +30,40 @@ provider_credential_schema:
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: base_url
|
||||
label:
|
||||
zh_Hans: 服务器 URL
|
||||
en_US: Base URL
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: Base URL, e.g. https://api.jina.ai/v1
|
||||
en_US: Base URL, e.g. https://api.jina.ai/v1
|
||||
default: 'https://api.jina.ai/v1'
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 上下文大小
|
||||
en_US: Context size
|
||||
placeholder:
|
||||
zh_Hans: 输入上下文大小
|
||||
en_US: Enter context size
|
||||
required: false
|
||||
type: text-input
|
||||
default: '8192'
|
||||
|
||||
@ -2,6 +2,8 @@ from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
@ -38,9 +40,13 @@ class JinaRerankModel(RerankModel):
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=[])
|
||||
|
||||
base_url = credentials.get('base_url', 'https://api.jina.ai/v1')
|
||||
if base_url.endswith('/'):
|
||||
base_url = base_url[:-1]
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
"https://api.jina.ai/v1/rerank",
|
||||
base_url + '/rerank',
|
||||
json={
|
||||
"model": model,
|
||||
"query": query,
|
||||
@ -103,3 +109,19 @@ class JinaRerankModel(RerankModel):
|
||||
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||
InvokeBadRequestError: [httpx.RequestError]
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.RERANK,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
|
||||
}
|
||||
)
|
||||
|
||||
return entity
|
||||
@ -4,7 +4,8 @@ from typing import Optional
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
@ -23,8 +24,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Jina text embedding model.
|
||||
"""
|
||||
api_base: str = 'https://api.jina.ai/v1/embeddings'
|
||||
models: list[str] = ['jina-embeddings-v2-base-en', 'jina-embeddings-v2-small-en', 'jina-embeddings-v2-base-zh', 'jina-embeddings-v2-base-de']
|
||||
api_base: str = 'https://api.jina.ai/v1'
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
@ -39,11 +39,14 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
:return: embeddings result
|
||||
"""
|
||||
api_key = credentials['api_key']
|
||||
if model not in self.models:
|
||||
raise InvokeBadRequestError('Invalid model name')
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError('api_key is required')
|
||||
url = self.api_base
|
||||
|
||||
base_url = credentials.get('base_url', self.api_base)
|
||||
if base_url.endswith('/'):
|
||||
base_url = base_url[:-1]
|
||||
|
||||
url = base_url + '/embeddings'
|
||||
headers = {
|
||||
'Authorization': 'Bearer ' + api_key,
|
||||
'Content-Type': 'application/json'
|
||||
@ -70,7 +73,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
elif response.status_code == 500:
|
||||
raise InvokeServerUnavailableError(msg)
|
||||
else:
|
||||
raise InvokeError(msg)
|
||||
raise InvokeBadRequestError(msg)
|
||||
except JSONDecodeError as e:
|
||||
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
|
||||
@ -118,8 +121,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=['ping'])
|
||||
except InvokeAuthorizationError:
|
||||
raise CredentialsValidateFailedError('Invalid api key')
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
@ -137,7 +140,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
KeyError
|
||||
KeyError,
|
||||
InvokeBadRequestError
|
||||
]
|
||||
}
|
||||
|
||||
@ -170,3 +174,19 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
|
||||
}
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 516 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 12 KiB |
29
api/core/model_runtime/model_providers/leptonai/leptonai.py
Normal file
29
api/core/model_runtime/model_providers/leptonai/leptonai.py
Normal file
@ -0,0 +1,29 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LeptonAIProvider(ModelProvider):
|
||||
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
model_instance.validate_credentials(
|
||||
model='llama2-7b',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
||||
@ -0,0 +1,29 @@
|
||||
provider: leptonai
|
||||
label:
|
||||
zh_Hans: Lepton AI
|
||||
en_US: Lepton AI
|
||||
icon_small:
|
||||
en_US: icon_s_en.png
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
background: "#F5F5F4"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API Key from Lepton AI
|
||||
zh_Hans: 从 Lepton AI 获取 API Key
|
||||
url:
|
||||
en_US: https://dashboard.lepton.ai
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
@ -0,0 +1,6 @@
|
||||
- gemma-7b
|
||||
- mistral-7b
|
||||
- mixtral-8x7b
|
||||
- llama2-7b
|
||||
- llama2-13b
|
||||
- llama3-70b
|
||||
@ -0,0 +1,20 @@
|
||||
model: gemma-7b
|
||||
label:
|
||||
zh_Hans: gemma-7b
|
||||
en_US: gemma-7b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
@ -0,0 +1,20 @@
|
||||
model: llama2-13b
|
||||
label:
|
||||
zh_Hans: llama2-13b
|
||||
en_US: llama2-13b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 4096
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
@ -0,0 +1,20 @@
|
||||
model: llama2-7b
|
||||
label:
|
||||
zh_Hans: llama2-7b
|
||||
en_US: llama2-7b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 4096
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
@ -0,0 +1,20 @@
|
||||
model: llama3-70b
|
||||
label:
|
||||
zh_Hans: llama3-70b
|
||||
en_US: llama3-70b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
34
api/core/model_runtime/model_providers/leptonai/llm/llm.py
Normal file
34
api/core/model_runtime/model_providers/leptonai/llm/llm.py
Normal file
@ -0,0 +1,34 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class LeptonAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
MODEL_PREFIX_MAP = {
|
||||
'llama2-7b': 'llama2-7b',
|
||||
'gemma-7b': 'gemma-7b',
|
||||
'mistral-7b': 'mistral-7b',
|
||||
'mixtral-8x7b': 'mixtral-8x7b',
|
||||
'llama3-70b': 'llama3-70b',
|
||||
'llama2-13b': 'llama2-13b',
|
||||
}
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials, model)
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials, model)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@classmethod
|
||||
def _add_custom_parameters(cls, credentials: dict, model: str) -> None:
|
||||
credentials['mode'] = 'chat'
|
||||
credentials['endpoint_url'] = f'https://{cls.MODEL_PREFIX_MAP[model]}.lepton.run/api/v1'
|
||||
|
||||
@ -0,0 +1,20 @@
|
||||
model: mistral-7b
|
||||
label:
|
||||
zh_Hans: mistral-7b
|
||||
en_US: mistral-7b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 8192
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
@ -0,0 +1,20 @@
|
||||
model: mixtral-8x7b
|
||||
label:
|
||||
zh_Hans: mixtral-8x7b
|
||||
en_US: mixtral-8x7b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 1024
|
||||
@ -15,6 +15,8 @@ help:
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
- speech2text
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
@ -57,6 +59,9 @@ model_credential_schema:
|
||||
zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
|
||||
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
|
||||
- variable: context_size
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
zh_Hans: 上下文大小
|
||||
en_US: Context size
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user