Mirror of https://github.com/langgenius/dify.git, synced 2026-04-19 02:07:29 +08:00

Compare commits: 158 commits, 0.6.9 ... feat/aws-i
| SHA1 | Author | Date |
|---|---|---|
| 909cc5a49f | |||
| cdc08a434f | |||
| 3f18369ad2 | |||
| db976a1f74 | |||
| e61f5d029a | |||
| eaca892c4e | |||
| 015c26d303 | |||
| 8210637bc5 | |||
| 790543131a | |||
| a40f68cf94 | |||
| adc948e87c | |||
| 742b08e1d5 | |||
| 79e8489942 | |||
| d6fa130cb5 | |||
| 0c92f81efc | |||
| 11fd4a5dcc | |||
| b399e8a359 | |||
| e04fc9b304 | |||
| ea69dc2a7e | |||
| ecc7f130b4 | |||
| 95443bd551 | |||
| 0ce97e6315 | |||
| 25b0a97851 | |||
| 28997772a5 | |||
| b7c72f7a97 | |||
| 9f7b38c068 | |||
| 3b36ba797f | |||
| 4d2e6c3391 | |||
| 3520d35f38 | |||
| 5f104bab57 | |||
| 2050a8b8f0 | |||
| e3544c6ef7 | |||
| 472b976946 | |||
| f62f71a81a | |||
| f426e1b3bd | |||
| 5f870ac950 | |||
| 415816cf35 | |||
| 9103112555 | |||
| 5986841e27 | |||
| 2573b138bf | |||
| 308ce66af5 | |||
| bdad993901 | |||
| 3b62ab564a | |||
| d319d9fc5e | |||
| ea5c8a72e2 | |||
| 3b60c28b3a | |||
| ea0219a5d5 | |||
| 481e7bc6b9 | |||
| 1ccba85c91 | |||
| 2539e56514 | |||
| 3929d289e0 | |||
| 52585aea74 | |||
| 73dee84cab | |||
| efecdccf35 | |||
| da5f2e168a | |||
| 5cdb95be1f | |||
| 7fa735a43b | |||
| 3579fd1b09 | |||
| 237b8fe3d9 | |||
| 02e4de5166 | |||
| 64c8093c1e | |||
| 0797f9bc05 | |||
| 602c4e51ec | |||
| 9f8ca75a81 | |||
| 80a87f36ea | |||
| 63addc9258 | |||
| f32b440c4a | |||
| 6b6afb7708 | |||
| a4041cb40b | |||
| 7749b71fff | |||
| 3006124e6d | |||
| 3d276f4a7f | |||
| b20d173324 | |||
| f44d1e62d2 | |||
| 21ac2afb3a | |||
| f7dd327bc2 | |||
| 09298a32e7 | |||
| 37f292ea91 | |||
| d1dbbc1e33 | |||
| 52ec152dd3 | |||
| c7bddb637b | |||
| 4e3b0c5aea | |||
| b6631cd878 | |||
| c212700341 | |||
| e121788ff5 | |||
| 96460d5ea3 | |||
| 9cf9720efa | |||
| 2d9f55b632 | |||
| 7133a16511 | |||
| a38dfc006e | |||
| 86e7c7321f | |||
| 58db719a2c | |||
| 9abeb99b32 | |||
| d828a7fc35 | |||
| c6f9ea4434 | |||
| fb6843815c | |||
| b97181a793 | |||
| 5d15aca85f | |||
| b98a1a3303 | |||
| 696c5308a9 | |||
| 3542d55e67 | |||
| 3c8a120e51 | |||
| cd24308f20 | |||
| 69190e088e | |||
| d058a234ba | |||
| 41e536109b | |||
| f916aa0f92 | |||
| cdbc260571 | |||
| b234710af9 | |||
| 23498883d4 | |||
| a47e8d0da2 | |||
| 6dd0e07af8 | |||
| b1c9671a60 | |||
| 7aaa1ff270 | |||
| 85698ca4f7 | |||
| 176d91937d | |||
| e0da0744b5 | |||
| 0b4902bdc2 | |||
| e9904e66e6 | |||
| 3de8e8fd6a | |||
| 38a470a873 | |||
| 4308a79e89 | |||
| 93d3350c8c | |||
| 615c009c42 | |||
| a325a294bd | |||
| 4b91383efc | |||
| 18ab63bd37 | |||
| a7fb1ffcd8 | |||
| 11f173693b | |||
| 5b2cd8d03a | |||
| b10e67be3b | |||
| d41c077fac | |||
| 3175a2c76a | |||
| 3b60b712ec | |||
| afed3610fc | |||
| 74f38eacda | |||
| b189faca52 | |||
| d4cd6149ac | |||
| e1cd9aef8f | |||
| ba37275503 | |||
| e01b44af61 | |||
| 72a90074bc | |||
| 705a6e3a8e | |||
| f4a240d225 | |||
| 793f0c1dd6 | |||
| 008edd0eeb | |||
| 9e6b6e7b82 | |||
| 164d6e47b9 | |||
| 88b4d69278 | |||
| 5bcbcd3c57 | |||
| 1b2d862973 | |||
| e6f6a59f3b | |||
| e198bc9b9a | |||
| b7f81f0999 | |||
| eb8dc15ad6 | |||
| 2ee3a1b6f3 | |||
| 0960b17fbc | |||
| 6534566b7e |
.github/workflows/api-tests.yml (vendored): 86 changes
@@ -4,6 +4,13 @@ on:
  pull_request:
    branches:
      - main
    paths:
      - api/**
      - docker/**

concurrency:
  group: api-tests-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  test:
@@ -51,7 +58,7 @@ jobs:
      - name: Run Workflow
        run: dev/pytest/pytest_workflow.sh

      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS)
      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
        uses: hoverkraft-tech/compose-action@v2.0.0
        with:
          compose-file: |
@@ -60,6 +67,7 @@ jobs:
            docker/docker-compose.milvus.yaml
            docker/docker-compose.pgvecto-rs.yaml
            docker/docker-compose.pgvector.yaml
            docker/docker-compose.chroma.yaml
          services: |
            weaviate
            qdrant
@@ -68,6 +76,82 @@ jobs:
            milvus-standalone
            pgvecto-rs
            pgvector
            chroma

      - name: Test Vector Stores
        run: dev/pytest/pytest_vdb.sh

  test-in-poetry:
    name: API Tests
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version:
          - "3.10"
          - "3.11"

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install Poetry
        uses: abatilo/actions-poetry@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'poetry'
          cache-dependency-path: |
            api/pyproject.toml
            api/poetry.lock

      - name: Poetry check
        run: poetry check -C api

      - name: Install dependencies
        run: poetry install -C api --with dev

      - name: Run Unit tests
        run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh

      - name: Run ModelRuntime
        run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh

      - name: Run Tool
        run: poetry run -C api bash dev/pytest/pytest_tools.sh

      - name: Set up Sandbox
        uses: hoverkraft-tech/compose-action@v2.0.0
        with:
          compose-file: |
            docker/docker-compose.middleware.yaml
          services: |
            sandbox
            ssrf_proxy

      - name: Run Workflow
        run: poetry run -C api bash dev/pytest/pytest_workflow.sh

      - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma)
        uses: hoverkraft-tech/compose-action@v2.0.0
        with:
          compose-file: |
            docker/docker-compose.middleware.yaml
            docker/docker-compose.qdrant.yaml
            docker/docker-compose.milvus.yaml
            docker/docker-compose.pgvecto-rs.yaml
            docker/docker-compose.pgvector.yaml
            docker/docker-compose.chroma.yaml
          services: |
            weaviate
            qdrant
            etcd
            minio
            milvus-standalone
            pgvecto-rs
            pgvector
            chroma

      - name: Test Vector Stores
        run: poetry run -C api bash dev/pytest/pytest_vdb.sh
.github/workflows/db-migration-test.yml (vendored, new file): 57 lines
@@ -0,0 +1,57 @@
name: DB Migration Test

on:
  pull_request:
    branches:
      - main
    paths:
      - api/migrations/**

concurrency:
  group: db-migration-test-${{ github.ref }}
  cancel-in-progress: true

jobs:
  db-migration-test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version:
          - "3.10"

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install Poetry
        uses: abatilo/actions-poetry@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'poetry'
          cache-dependency-path: |
            api/pyproject.toml
            api/poetry.lock

      - name: Install dependencies
        run: poetry install -C api

      - name: Set up Middleware
        uses: hoverkraft-tech/compose-action@v2.0.0
        with:
          compose-file: |
            docker/docker-compose.middleware.yaml
          services: |
            db

      - name: Prepare configs
        run: |
          cd api
          cp .env.example .env

      - name: Run DB Migration
        run: |
          cd api
          poetry run python -m flask db upgrade
.github/workflows/style.yml (vendored): 67 changes
@@ -6,7 +6,7 @@ on:
      - main

concurrency:
  group: dep-${{ github.head_ref || github.run_id }}
  group: style-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
@@ -18,54 +18,92 @@ jobs:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Check changed files
        id: changed-files
        uses: tj-actions/changed-files@v44
        with:
          files: api/**

      - name: Install Poetry
        uses: abatilo/actions-poetry@v3

      - name: Set up Python
        uses: actions/setup-python@v5
        if: steps.changed-files.outputs.any_changed == 'true'
        with:
          python-version: '3.10'

      - name: Python dependencies
        run: pip install ruff dotenv-linter
        if: steps.changed-files.outputs.any_changed == 'true'
        run: poetry install -C api --only lint

      - name: Ruff check
        run: ruff check ./api
        if: steps.changed-files.outputs.any_changed == 'true'
        run: poetry run -C api ruff check --preview ./api

      - name: Dotenv check
        run: dotenv-linter ./api/.env.example ./web/.env.example
        if: steps.changed-files.outputs.any_changed == 'true'
        run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example

      - name: Lint hints
        if: failure()
        run: echo "Please run 'dev/reformat' to fix the fixable linting errors."

  test:
    name: ESLint and SuperLinter
  web-style:
    name: Web Style
    runs-on: ubuntu-latest
    needs: python-style
    defaults:
      run:
        working-directory: ./web

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Check changed files
        id: changed-files
        uses: tj-actions/changed-files@v44
        with:
          fetch-depth: 0
          files: web/**

      - name: Setup NodeJS
        uses: actions/setup-node@v4
        if: steps.changed-files.outputs.any_changed == 'true'
        with:
          node-version: 20
          cache: yarn
          cache-dependency-path: ./web/package.json

      - name: Web dependencies
        run: |
          cd ./web
          yarn install --frozen-lockfile
        if: steps.changed-files.outputs.any_changed == 'true'
        run: yarn install --frozen-lockfile

      - name: Web style check
        run: |
          cd ./web
          yarn run lint
        if: steps.changed-files.outputs.any_changed == 'true'
        run: yarn run lint


  superlinter:
    name: SuperLinter
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Check changed files
        id: changed-files
        uses: tj-actions/changed-files@v44
        with:
          files: |
            **.sh
            **.yaml
            **.yml
            Dockerfile

      - name: Super-linter
        uses: super-linter/super-linter/slim@v6
        if: steps.changed-files.outputs.any_changed == 'true'
        env:
          BASH_SEVERITY: warning
          DEFAULT_BRANCH: main
@@ -76,4 +114,5 @@ jobs:
          VALIDATE_BASH_EXEC: true
          VALIDATE_GITHUB_ACTIONS: true
          VALIDATE_DOCKERFILE_HADOLINT: true
          VALIDATE_XML: true
          VALIDATE_YAML: true
.github/workflows/tool-test-sdks.yaml (vendored): 7 changes
@@ -4,6 +4,13 @@ on:
  pull_request:
    branches:
      - main
    paths:
      - sdks/**

concurrency:
  group: sdk-tests-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    name: unit test for Node.js SDK
.gitignore (vendored): 4 changes
@@ -134,7 +134,8 @@ dmypy.json
web/.vscode/settings.json

# Intellij IDEA Files
.idea/
.idea/*
!.idea/vcs.xml
.ideaDataSources/

api/.env
@@ -148,6 +149,7 @@ docker/volumes/qdrant/*
docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*
docker/volumes/chroma/*

sdks/python-client/build
sdks/python-client/dist
.idea/vcs.xml (generated, new file): 16 lines
@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="IssueNavigationConfiguration">
    <option name="links">
      <list>
        <IssueNavigationLink>
          <option name="issueRegexp" value="#(\d+)" />
          <option name="linkRegexp" value="https://github.com/langgenius/dify/issues/$1" />
        </IssueNavigationLink>
      </list>
    </option>
  </component>
  <component name="VcsDirectoryMappings">
    <mapping directory="" vcs="Git" />
  </component>
</project>
@@ -36,6 +36,7 @@
  <a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
  <a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
  <a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
  <a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
</p>
README_AR.md (new file): 225 lines
@@ -0,0 +1,225 @@

<p align="center">
  <a href="https://cloud.dify.ai">Dify Cloud</a> ·
  <a href="https://docs.dify.ai/getting-started/install-self-hosted">Self-hosting</a> ·
  <a href="https://docs.dify.ai">Documentation</a> ·
  <a href="https://cal.com/guchenhe/60-min-meeting">Enterprise inquiries</a>
</p>

<p align="center">
  <a href="https://dify.ai" target="_blank">
    <img alt="Static Badge" src="https://img.shields.io/badge/Product-F04438"></a>
  <a href="https://dify.ai/pricing" target="_blank">
    <img alt="Static Badge" src="https://img.shields.io/badge/free-pricing?logo=free&color=%20%23155EEF&label=pricing&labelColor=%20%23528bff"></a>
  <a href="https://discord.gg/FngNHpbcY7" target="_blank">
    <img src="https://img.shields.io/discord/1082486657678311454?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb"
        alt="chat on Discord"></a>
  <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
    <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
        alt="follow on Twitter"></a>
  <a href="https://hub.docker.com/u/langgenius" target="_blank">
    <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
  <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
    <img alt="Commits last month" src="https://img.shields.io/github/commit-activity/m/langgenius/dify?labelColor=%20%2332b583&color=%20%2312b76a"></a>
  <a href="https://github.com/langgenius/dify/" target="_blank">
    <img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
  <a href="https://github.com/langgenius/dify/discussions/" target="_blank">
    <img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
</p>

<p align="center">
  <a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
  <a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
  <a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
  <a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
  <a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
  <a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
  <a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
  <a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
</p>

Dify is an open-source AI application development platform. Its intuitive interface combines AI workflows, a RAG pipeline, agent capabilities, model management, observability features, and more, letting you move quickly from prototype to production. Here is a list of the core features:
</br> </br>

**1. Workflow**: Build and test powerful AI workflows on a visual canvas, leveraging all of the following features and more.

https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa


**2. Comprehensive model support**: Seamless integration with hundreds of proprietary and open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI-API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).

**3. Prompt IDE**: An intuitive IDE interface for crafting prompts, comparing model performance, and adding extra features such as text-to-speech to a chat-based app.

**4. RAG pipeline**: Extensive RAG capabilities covering everything from document ingestion to retrieval, with out-of-the-box support for text extraction from PDF, PPT, and other common document formats.

**5. Agent capabilities**: You can define agents based on LLM Function Calling or ReAct, and add built-in or custom tools for the agent. Dify provides more than 50 built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion, and WolframAlpha.

**6. LLMOps**: Monitor and analyze application logs and performance over time. You can continuously improve prompts, data, and models based on production data and feedback.

**7. Backend as a Service**: All of Dify's offerings come with matching APIs, so you can easily integrate Dify into your own business logic.
## Feature comparison
<table style="width: 100%;">
  <tr>
    <th align="center">Feature</th>
    <th align="center">Dify.AI</th>
    <th align="center">LangChain</th>
    <th align="center">Flowise</th>
    <th align="center">OpenAI Assistants API</th>
  </tr>
  <tr>
    <td align="center">Programming approach</td>
    <td align="center">API + app-oriented</td>
    <td align="center">Python code</td>
    <td align="center">App-oriented</td>
    <td align="center">API-oriented</td>
  </tr>
  <tr>
    <td align="center">Supported LLMs</td>
    <td align="center">Rich variety</td>
    <td align="center">Rich variety</td>
    <td align="center">Rich variety</td>
    <td align="center">OpenAI only</td>
  </tr>
  <tr>
    <td align="center">RAG engine</td>
    <td align="center">✅</td>
    <td align="center">✅</td>
    <td align="center">✅</td>
    <td align="center">✅</td>
  </tr>
  <tr>
    <td align="center">Agent</td>
    <td align="center">✅</td>
    <td align="center">✅</td>
    <td align="center">❌</td>
    <td align="center">✅</td>
  </tr>
  <tr>
    <td align="center">Workflow</td>
    <td align="center">✅</td>
    <td align="center">❌</td>
    <td align="center">✅</td>
    <td align="center">❌</td>
  </tr>
  <tr>
    <td align="center">Observability</td>
    <td align="center">✅</td>
    <td align="center">✅</td>
    <td align="center">❌</td>
    <td align="center">❌</td>
  </tr>
  <tr>
    <td align="center">Enterprise features (SSO / access control)</td>
    <td align="center">✅</td>
    <td align="center">❌</td>
    <td align="center">❌</td>
    <td align="center">❌</td>
  </tr>
  <tr>
    <td align="center">Local deployment</td>
    <td align="center">✅</td>
    <td align="center">✅</td>
    <td align="center">✅</td>
    <td align="center">❌</td>
  </tr>
</table>


## Using Dify
- **Cloud </br>**
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-hosted version, and includes 200 free GPT-4 calls on the sandbox plan.

- **Self-hosting Dify Community Edition</br>**
Quickly get Dify running in your environment with this [quick start guide](#quick-start).
Use our [documentation](https://docs.dify.ai) for further references and deeper instructions.

- **Dify for enterprises / organizations</br>**
We provide additional enterprise-centric features. [Schedule a meeting with us](https://cal.com/guchenhe/30min) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs. </br>
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It is an affordable AMI offering with the option to create apps with custom logo and branding.
## Staying ahead

Star Dify on GitHub and be instantly notified of new releases.

## Quick start
> Before installing Dify, make sure your machine meets the following minimum system requirements:
>
> - CPU >= 2 cores
> - RAM >= 4 GB

</br>

The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:

```bash
cd docker
docker compose up -d
```
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.

> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)

## Next steps

If you need to customize the configuration, please refer to the comments in our [docker-compose.yml](docker/docker-compose.yaml) file and set the environment configurations manually. After making the changes, run `docker-compose up -d` again. You can see the full list of environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).

If you'd like to configure a highly available setup, community-contributed [Helm Charts](https://helm.sh/) allow Dify to be deployed on Kubernetes.

- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)


## Contributing

For those who'd like to contribute, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.


> We are looking for contributors to help translate Dify into languages other than Mandarin Chinese and English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).

**Contributors**

<a href="https://github.com/langgenius/dify/graphs/contributors">
  <img src="https://contrib.rocks/image?repo=langgenius/dify" />
</a>

## Community & contact
* [GitHub Discussions](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
* [Email](mailto:support@dify.ai?subject=[GitHub]Questions%20About%20Dify). Best for: questions you have about using Dify.AI.
* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
* [Twitter](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.

Or, schedule a meeting directly with a team member:

<table>
  <tr>
    <th>Point of contact</th>
    <th>Purpose</th>
  </tr>
  <tr>
    <td><a href='https://cal.com/guchenhe/15min' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/9ebcd111-1205-4d71-83d5-948d70b809f5' alt='Git-Hub-README-Button-3x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
    <td>Business inquiries and product feedback</td>
  </tr>
  <tr>
    <td><a href='https://cal.com/pinkbanana' target='_blank'><img class="schedule-button" src='https://github.com/langgenius/dify/assets/13230914/d1edd00a-d7e4-4513-be6c-e57038e143fd' alt='Git-Hub-README-Button-2x' style="width: 180px; height: auto; object-fit: contain;"/></a></td>
    <td>Contributions, issues, and feature requests</td>
  </tr>
</table>

## Star history

[](https://star-history.com/#langgenius/dify&Date)


## Security disclosure

To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer.

## License

This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions.
@@ -112,6 +112,21 @@ PGVECTOR_USER=postgres
PGVECTOR_PASSWORD=postgres
PGVECTOR_DATABASE=postgres

# Tidb Vector configuration
TIDB_VECTOR_HOST=xxx.eu-central-1.xxx.aws.tidbcloud.com
TIDB_VECTOR_PORT=4000
TIDB_VECTOR_USER=xxx.root
TIDB_VECTOR_PASSWORD=xxxxxx
TIDB_VECTOR_DATABASE=dify

# Chroma configuration
CHROMA_HOST=127.0.0.1
CHROMA_PORT=8000
CHROMA_TENANT=default_tenant
CHROMA_DATABASE=default_database
CHROMA_AUTH_PROVIDER=chromadb.auth.token_authn.TokenAuthenticationServerProvider
CHROMA_AUTH_CREDENTIALS=difyai123456

# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
@@ -127,10 +142,11 @@ RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
SMTP_SERVER=smtp.gmail.com
SMTP_PORT=587
SMTP_PORT=465
SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=false
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false

# Sentry configuration
SENTRY_DSN=
@@ -182,3 +198,8 @@ LOG_FILE=

# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000

# Workflow runtime configuration
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
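The SMTP defaults above move from port 587 with `SMTP_USE_TLS=false` to port 465 with `SMTP_USE_TLS=true`, plus a new `SMTP_OPPORTUNISTIC_TLS` flag. A minimal sketch of what these two flags conventionally select, based on standard `smtplib` semantics rather than on Dify's actual mail code:

```python
import smtplib

def connect(server: str, port: int, use_tls: bool, opportunistic_tls: bool) -> smtplib.SMTP:
    # Sketch only: the flag names mirror the .env keys above; the mapping
    # onto smtplib is an assumption about conventional SMTP behavior.
    if use_tls and opportunistic_tls:
        # Port 587 style: plain connection upgraded in-band via STARTTLS.
        conn = smtplib.SMTP(server, port)
        conn.starttls()
    elif use_tls:
        # Port 465 style (the new defaults): TLS from the first byte.
        conn = smtplib.SMTP_SSL(server, port)
    else:
        conn = smtplib.SMTP(server, port)
    return conn
```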
api/.vscode/launch.json (vendored): 6 changes
@@ -17,7 +17,8 @@
        "FLASK_DEBUG": "1",
        "GEVENT_SUPPORT": "True"
      },
      "console": "integratedTerminal"
      "console": "integratedTerminal",
      "python": "${command:python.interpreterPath}"
    },
    {
      "name": "Python: Flask",
@@ -36,7 +37,8 @@
        "--debug"
      ],
      "jinja": true,
      "justMyCode": true
      "justMyCode": true,
      "python": "${command:python.interpreterPath}"
    }
  ]
}
@@ -17,15 +17,30 @@
   ```bash
   sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
   ```
4. If you use Anaconda, create a new environment and activate it
4. Create environment.
   - Anaconda
     If you use Anaconda, create a new environment and activate it
   ```bash
   conda create --name dify python=3.10
   conda activate dify
   ```
   - Poetry
     If you use Poetry, you don't need to create the environment manually; you can execute `poetry shell` to activate it.
5. Install dependencies
   - Anaconda
   ```bash
   pip install -r requirements.txt
   ```
   - Poetry
   ```bash
   poetry install
   ```
   In case contributors have missed updating the dependencies in `pyproject.toml`, you can run the following shell commands instead.
   ```bash
   poetry shell                                          # activate the current environment
   poetry add $(cat requirements.txt)                    # install production dependencies and update pyproject.toml
   poetry add $(cat requirements-dev.txt) --group dev    # install development dependencies and update pyproject.toml
   ```
6. Run migrate

   Before the first launch, migrate the database to the latest version.
@@ -1,12 +1,15 @@
import base64
import json
import secrets
from typing import Optional

import click
from flask import current_app
from werkzeug.exceptions import NotFound

from constants.languages import languages
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_database import db
from libs.helper import email as email_validate
@@ -17,6 +20,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from services.account_service import RegisterService, TenantService


@click.command('reset-password', help='Reset the account password.')
@@ -57,7 +61,7 @@ def reset_password(email, new_password, password_confirm):
    account.password = base64_password_hashed
    account.password_salt = base64_salt
    db.session.commit()
    click.echo(click.style('Congratulations!, password has been reset.', fg='green'))
    click.echo(click.style('Congratulations! Password has been reset.', fg='green'))


@click.command('reset-email', help='Reset the account email.')
@@ -263,15 +267,15 @@ def migrate_knowledge_vector_database():
            skipped_count = skipped_count + 1
            continue
        collection_name = ''
        if vector_type == "weaviate":
        if vector_type == VectorType.WEAVIATE:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            index_struct_dict = {
                "type": 'weaviate',
                "type": VectorType.WEAVIATE,
                "vector_store": {"class_prefix": collection_name}
            }
            dataset.index_struct = json.dumps(index_struct_dict)
        elif vector_type == "qdrant":
        elif vector_type == VectorType.QDRANT:
            if dataset.collection_binding_id:
                dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                    filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
@@ -284,20 +288,20 @@ def migrate_knowledge_vector_database():
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            index_struct_dict = {
                "type": 'qdrant',
                "type": VectorType.QDRANT,
                "vector_store": {"class_prefix": collection_name}
            }
            dataset.index_struct = json.dumps(index_struct_dict)

        elif vector_type == "milvus":
        elif vector_type == VectorType.MILVUS:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            index_struct_dict = {
                "type": 'milvus',
                "type": VectorType.MILVUS,
                "vector_store": {"class_prefix": collection_name}
            }
            dataset.index_struct = json.dumps(index_struct_dict)
        elif vector_type == "relyt":
        elif vector_type == VectorType.RELYT:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            index_struct_dict = {
@@ -305,16 +309,16 @@ def migrate_knowledge_vector_database():
                "vector_store": {"class_prefix": collection_name}
            }
            dataset.index_struct = json.dumps(index_struct_dict)
        elif vector_type == "pgvector":
        elif vector_type == VectorType.PGVECTOR:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
            index_struct_dict = {
                "type": 'pgvector',
                "type": VectorType.PGVECTOR,
                "vector_store": {"class_prefix": collection_name}
            }
            dataset.index_struct = json.dumps(index_struct_dict)
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
            raise ValueError(f"Vector store {vector_type} is not supported.")

        vector = Vector(dataset)
        click.echo(f"Start to migrate dataset {dataset.id}.")
@@ -501,6 +505,46 @@ def add_qdrant_doc_id_index(field: str):
                        fg='green'))


@click.command('create-tenant', help='Create account and tenant.')
@click.option('--email', prompt=True, help='The email address of the tenant account.')
@click.option('--language', prompt=True, help='Account language, default: en-US.')
def create_tenant(email: str, language: Optional[str] = None):
    """
    Create tenant account
    """
    if not email:
        click.echo(click.style('Sorry, email is required.', fg='red'))
        return

    # Create account
    email = email.strip()

    if '@' not in email:
        click.echo(click.style('Sorry, invalid email address.', fg='red'))
        return

    account_name = email.split('@')[0]

    if language not in languages:
        language = 'en-US'

    # generate random password
    new_password = secrets.token_urlsafe(16)

    # register account
    account = RegisterService.register(
        email=email,
        name=account_name,
        password=new_password,
        language=language
    )

    TenantService.create_owner_tenant_if_not_exist(account)

    click.echo(click.style('Congratulations! Account and tenant created.\n'
                           'Account: {}\nPassword: {}'.format(email, new_password), fg='green'))


def register_commands(app):
    app.cli.add_command(reset_password)
    app.cli.add_command(reset_email)
@@ -508,4 +552,5 @@ def register_commands(app):
    app.cli.add_command(vdb_migrate)
    app.cli.add_command(convert_to_agent_apps)
    app.cli.add_command(add_qdrant_doc_id_index)
    app.cli.add_command(create_tenant)
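The new `create-tenant` command registers an account with a random password and creates its owner tenant. A hedged sketch of invoking it through Flask's built-in test CLI runner; the `create_app` import path is an assumption about the API's application factory:

```python
from app import create_app  # assumed location of the Flask application factory

app = create_app()
runner = app.test_cli_runner()

# Equivalent to: flask create-tenant --email admin@example.com --language en-US
result = runner.invoke(args=['create-tenant',
                             '--email', 'admin@example.com',
                             '--language', 'en-US'])
print(result.output)  # echoes the account email and the generated password
```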
@@ -70,6 +70,7 @@ DEFAULTS = {
    'INVITE_EXPIRY_HOURS': 72,
    'BILLING_ENABLED': 'False',
    'CAN_REPLACE_LOGO': 'False',
    'MODEL_LB_ENABLED': 'False',
    'ETL_TYPE': 'dify',
    'KEYWORD_STORE': 'jieba',
    'BATCH_UPLOAD_LIMIT': 20,
@@ -81,8 +82,9 @@ DEFAULTS = {
    'INNER_API': 'False',
    'ENTERPRISE_ENABLED': 'False',
    'INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH': 1000,
    'WORKFLOW_MAX_EXECUTION_STEPS': 50,
    'WORKFLOW_MAX_EXECUTION_TIME': 600,
    'WORKFLOW_MAX_EXECUTION_STEPS': 500,
    'WORKFLOW_MAX_EXECUTION_TIME': 1200,
    'WORKFLOW_CALL_MAX_DEPTH': 5,
}


@@ -113,7 +115,7 @@ class Config:
        # ------------------------
        # General Configurations.
        # ------------------------
        self.CURRENT_VERSION = "0.6.9"
        self.CURRENT_VERSION = "0.6.10"
        self.COMMIT_SHA = get_env('COMMIT_SHA')
        self.EDITION = get_env('EDITION')
        self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -122,6 +124,7 @@ class Config:
        self.LOG_FILE = get_env('LOG_FILE')
        self.LOG_FORMAT = get_env('LOG_FORMAT')
        self.LOG_DATEFORMAT = get_env('LOG_DATEFORMAT')
        self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')

        # The backend URL prefix of the console API.
        # used to concatenate the login authorization callback or notion integration callback.
@@ -209,27 +212,41 @@ class Config:
            if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
        self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')

        # ------------------------
        # Code Execution Sandbox Configurations.
        # ------------------------
        self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
        self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')

        # ------------------------
        # File Storage Configurations.
        # ------------------------
        self.STORAGE_TYPE = get_env('STORAGE_TYPE')
        self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')

        # S3 Storage settings
        self.S3_ENDPOINT = get_env('S3_ENDPOINT')
        self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
        self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
        self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
        self.S3_REGION = get_env('S3_REGION')
        self.S3_ADDRESS_STYLE = get_env('S3_ADDRESS_STYLE')

        # Azure Blob Storage settings
        self.AZURE_BLOB_ACCOUNT_NAME = get_env('AZURE_BLOB_ACCOUNT_NAME')
        self.AZURE_BLOB_ACCOUNT_KEY = get_env('AZURE_BLOB_ACCOUNT_KEY')
        self.AZURE_BLOB_CONTAINER_NAME = get_env('AZURE_BLOB_CONTAINER_NAME')
        self.AZURE_BLOB_ACCOUNT_URL = get_env('AZURE_BLOB_ACCOUNT_URL')

        # Aliyun Storage settings
        self.ALIYUN_OSS_BUCKET_NAME = get_env('ALIYUN_OSS_BUCKET_NAME')
        self.ALIYUN_OSS_ACCESS_KEY = get_env('ALIYUN_OSS_ACCESS_KEY')
        self.ALIYUN_OSS_SECRET_KEY = get_env('ALIYUN_OSS_SECRET_KEY')
        self.ALIYUN_OSS_ENDPOINT = get_env('ALIYUN_OSS_ENDPOINT')
        self.ALIYUN_OSS_REGION = get_env('ALIYUN_OSS_REGION')
        self.ALIYUN_OSS_AUTH_VERSION = get_env('ALIYUN_OSS_AUTH_VERSION')

        # Google Cloud Storage settings
        self.GOOGLE_STORAGE_BUCKET_NAME = get_env('GOOGLE_STORAGE_BUCKET_NAME')
        self.GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64 = get_env('GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64')

@@ -239,6 +256,7 @@ class Config:
        # ------------------------
        self.VECTOR_STORE = get_env('VECTOR_STORE')
        self.KEYWORD_STORE = get_env('KEYWORD_STORE')

        # qdrant settings
        self.QDRANT_URL = get_env('QDRANT_URL')
        self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
@@ -281,6 +299,21 @@ class Config:
        self.PGVECTOR_PASSWORD = get_env('PGVECTOR_PASSWORD')
        self.PGVECTOR_DATABASE = get_env('PGVECTOR_DATABASE')

        # tidb-vector settings
        self.TIDB_VECTOR_HOST = get_env('TIDB_VECTOR_HOST')
        self.TIDB_VECTOR_PORT = get_env('TIDB_VECTOR_PORT')
        self.TIDB_VECTOR_USER = get_env('TIDB_VECTOR_USER')
        self.TIDB_VECTOR_PASSWORD = get_env('TIDB_VECTOR_PASSWORD')
        self.TIDB_VECTOR_DATABASE = get_env('TIDB_VECTOR_DATABASE')

        # chroma settings
        self.CHROMA_HOST = get_env('CHROMA_HOST')
        self.CHROMA_PORT = get_env('CHROMA_PORT')
        self.CHROMA_TENANT = get_env('CHROMA_TENANT')
        self.CHROMA_DATABASE = get_env('CHROMA_DATABASE')
        self.CHROMA_AUTH_PROVIDER = get_env('CHROMA_AUTH_PROVIDER')
        self.CHROMA_AUTH_CREDENTIALS = get_env('CHROMA_AUTH_CREDENTIALS')

        # ------------------------
        # Mail Configurations.
        # ------------------------
@@ -294,6 +327,7 @@ class Config:
        self.SMTP_USERNAME = get_env('SMTP_USERNAME')
        self.SMTP_PASSWORD = get_env('SMTP_PASSWORD')
        self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
        self.SMTP_OPPORTUNISTIC_TLS = get_bool_env('SMTP_OPPORTUNISTIC_TLS')

        # ------------------------
        # Workspace Configurations.
@@ -321,9 +355,23 @@ class Config:
        self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
        self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
        self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
        self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')

        # RAG ETL Configurations.
        self.ETL_TYPE = get_env('ETL_TYPE')
        self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
        self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY')
        self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')

        # Indexing Configurations.
        self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH')

        # Tool Configurations.
        self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')

        self.WORKFLOW_MAX_EXECUTION_STEPS = int(get_env('WORKFLOW_MAX_EXECUTION_STEPS'))
        self.WORKFLOW_MAX_EXECUTION_TIME = int(get_env('WORKFLOW_MAX_EXECUTION_TIME'))
        self.WORKFLOW_CALL_MAX_DEPTH = int(get_env('WORKFLOW_CALL_MAX_DEPTH'))

        # Moderation in app Configurations.
        self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
@@ -375,24 +423,15 @@ class Config:
        self.HOSTED_FETCH_APP_TEMPLATES_MODE = get_env('HOSTED_FETCH_APP_TEMPLATES_MODE')
        self.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = get_env('HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN')

        self.ETL_TYPE = get_env('ETL_TYPE')
        self.UNSTRUCTURED_API_URL = get_env('UNSTRUCTURED_API_URL')
        self.UNSTRUCTURED_API_KEY = get_env('UNSTRUCTURED_API_KEY')
        # Model Load Balancing Configurations.
        self.MODEL_LB_ENABLED = get_bool_env('MODEL_LB_ENABLED')

        # Platform Billing Configurations.
        self.BILLING_ENABLED = get_bool_env('BILLING_ENABLED')
        self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')

        self.BATCH_UPLOAD_LIMIT = get_env('BATCH_UPLOAD_LIMIT')

        self.CODE_EXECUTION_ENDPOINT = get_env('CODE_EXECUTION_ENDPOINT')
        self.CODE_EXECUTION_API_KEY = get_env('CODE_EXECUTION_API_KEY')

        self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
        self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')

        self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')
        # ------------------------
        # Enterprise feature Configurations.
        # **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
        # ------------------------
        self.ENTERPRISE_ENABLED = get_bool_env('ENTERPRISE_ENABLED')

        # ------------------------
        # Indexing Configurations.
        # ------------------------
        self.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH = get_env('INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH')
        self.CAN_REPLACE_LOGO = get_bool_env('CAN_REPLACE_LOGO')
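For context, the `DEFAULTS` table above backs the `get_env` / `get_bool_env` helpers: a variable is read from the environment and falls back to the table when unset. A self-contained sketch of that pattern (the real helpers live in Dify's config module; these reimplementations are illustrative only):

```python
import os

DEFAULTS = {
    'WORKFLOW_MAX_EXECUTION_STEPS': 500,   # raised from 50 in this release
    'WORKFLOW_MAX_EXECUTION_TIME': 1200,   # raised from 600
    'MODEL_LB_ENABLED': 'False',
}

def get_env(key: str):
    # The environment wins; otherwise fall back to the DEFAULTS table.
    return os.environ.get(key, DEFAULTS.get(key))

def get_bool_env(key: str) -> bool:
    # Boolean flags are stored as the strings 'True' / 'False'.
    return str(get_env(key)).lower() == 'true'

print(int(get_env('WORKFLOW_MAX_EXECUTION_STEPS')))  # 500 unless overridden
print(get_bool_env('MODEL_LB_ENABLED'))              # False by default
```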
File diff suppressed because one or more lines are too long
@@ -54,4 +54,4 @@ from .explore import (
from .tag import tags

# Import workspace controllers
from .workspace import account, members, model_providers, models, tool_providers, workspace
from .workspace import account, load_balancing_config, members, model_providers, models, tool_providers, workspace
@@ -85,7 +85,7 @@ class ChatMessageTextApi(Resource):
        response = AudioService.transcript_tts(
            app_model=app_model,
            text=request.form['text'],
            voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
            voice=request.form['voice'],
            streaming=False
        )
@@ -15,6 +15,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.app_fields import related_app_list
@@ -476,20 +477,22 @@ class DatasetRetrievalSettingApi(Resource):
    @account_initialization_required
    def get(self):
        vector_type = current_app.config['VECTOR_STORE']
        if vector_type in {"milvus", "relyt", "pgvector", "pgvecto_rs"}:
            return {
                'retrieval_method': [
                    'semantic_search'
                ]
            }
        elif vector_type in {"qdrant", "weaviate"}:
            return {
                'retrieval_method': [
                    'semantic_search', 'full_text_search', 'hybrid_search'
                ]
            }
        else:
            raise ValueError("Unsupported vector db type.")

        match vector_type:
            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
                return {
                    'retrieval_method': [
                        'semantic_search'
                    ]
                }
            case VectorType.QDRANT | VectorType.WEAVIATE:
                return {
                    'retrieval_method': [
                        'semantic_search', 'full_text_search', 'hybrid_search'
                    ]
                }
            case _:
                raise ValueError(f"Unsupported vector db type {vector_type}.")


class DatasetRetrievalSettingMockApi(Resource):
@@ -497,20 +500,22 @@ class DatasetRetrievalSettingMockApi(Resource):
    @login_required
    @account_initialization_required
    def get(self, vector_type):
        if vector_type in {'milvus', 'relyt', 'pgvector'}:
            return {
                'retrieval_method': [
                    'semantic_search'
                ]
            }
        elif vector_type in {'qdrant', 'weaviate'}:
            return {
                'retrieval_method': [
                    'semantic_search', 'full_text_search', 'hybrid_search'
                ]
            }
        else:
            raise ValueError("Unsupported vector db type.")
        match vector_type:
            case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA:
                return {
                    'retrieval_method': [
                        'semantic_search'
                    ]
                }
            case VectorType.QDRANT | VectorType.WEAVIATE:
                return {
                    'retrieval_method': [
                        'semantic_search', 'full_text_search', 'hybrid_search'
                    ]
                }
            case _:
                raise ValueError(f"Unsupported vector db type {vector_type}.")


class DatasetErrorDocs(Resource):
    @setup_required
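The rewrite works because `VectorType` is, by all appearances, a str-based enum, so a plain config string such as "qdrant" compares equal to `VectorType.QDRANT` inside a `match` value pattern. A self-contained sketch of the same dispatch (the enum members here are a local, illustrative subset):

```python
from enum import Enum

class VectorType(str, Enum):
    WEAVIATE = 'weaviate'
    QDRANT = 'qdrant'
    MILVUS = 'milvus'

def retrieval_methods(vector_type: str) -> list[str]:
    match vector_type:
        case VectorType.MILVUS:
            return ['semantic_search']
        case VectorType.QDRANT | VectorType.WEAVIATE:
            return ['semantic_search', 'full_text_search', 'hybrid_search']
        case _:
            raise ValueError(f"Unsupported vector db type {vector_type}.")

print(retrieval_methods('qdrant'))  # str == str-enum, so the case matches
```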
@@ -1,10 +1,12 @@
import logging
from argparse import ArgumentTypeError
from datetime import datetime, timezone

from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc
from transformers.hf_argparser import string_to_bool
from werkzeug.exceptions import Forbidden, NotFound

import services
@@ -141,7 +143,11 @@ class DatasetDocumentListApi(Resource):
        limit = request.args.get('limit', default=20, type=int)
        search = request.args.get('keyword', default=None, type=str)
        sort = request.args.get('sort', default='-created_at', type=str)
        fetch = request.args.get('fetch', default=False, type=bool)
        # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
        try:
            fetch = string_to_bool(request.args.get('fetch', default='false'))
        except (ArgumentTypeError, ValueError, Exception) as e:
            fetch = False
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound('Dataset not found.')
@@ -924,6 +930,28 @@ class DocumentRetryApi(DocumentResource):
        return {'result': 'success'}, 204


class DocumentRenameApi(DocumentResource):
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(document_fields)
    def post(self, dataset_id, document_id):
        # The role of the current user in the ta table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, nullable=False, location='json')
        args = parser.parse_args()

        try:
            document = DocumentService.rename_document(dataset_id, document_id, args['name'])
        except services.errors.document.DocumentIndexingError:
            raise DocumentIndexingError('Cannot delete document during indexing.')

        return document


api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
api.add_resource(DatasetDocumentListApi,
                 '/datasets/<uuid:dataset_id>/documents')
@@ -950,3 +978,5 @@ api.add_resource(DocumentStatusApi,
api.add_resource(DocumentPauseApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause')
api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume')
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
api.add_resource(DocumentRenameApi,
                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename')
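A hedged sketch of calling the new rename route; host, port, IDs, the bearer token, and the console API prefix are all placeholders or assumptions:

```python
import requests

base = 'http://localhost:5001/console/api'             # assumed console API prefix
dataset_id, document_id = '<dataset-uuid>', '<document-uuid>'

resp = requests.post(
    f'{base}/datasets/{dataset_id}/documents/{document_id}/rename',
    json={'name': 'quarterly-report.pdf'},             # 'name' is required, non-null
    headers={'Authorization': 'Bearer <console-token>'},
)
print(resp.json())  # marshalled document fields on success; 403 unless admin/owner
```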
@@ -1,14 +1,19 @@
from flask_login import current_user
from flask_restful import Resource

from libs.login import login_required
from services.feature_service import FeatureService

from . import api
from .wraps import cloud_utm_record
from .setup import setup_required
from .wraps import account_initialization_required, cloud_utm_record


class FeatureApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_utm_record
    def get(self):
        return FeatureService.get_features(current_user.current_tenant_id).dict()
@@ -17,13 +17,19 @@ class VersionApi(Resource):
        args = parser.parse_args()
        check_update_url = current_app.config['CHECK_UPDATE_URL']

        if not check_update_url:
            return {
                'version': '0.0.0',
                'release_date': '',
                'release_notes': '',
                'can_auto_update': False
        result = {
            'version': current_app.config['CURRENT_VERSION'],
            'release_date': '',
            'release_notes': '',
            'can_auto_update': False,
            'features': {
                'can_replace_logo': current_app.config['CAN_REPLACE_LOGO'],
                'model_load_balancing_enabled': current_app.config['MODEL_LB_ENABLED']
            }
        }

        if not check_update_url:
            return result

        try:
            response = requests.get(check_update_url, {
@@ -31,20 +37,15 @@ class VersionApi(Resource):
            })
        except Exception as error:
            logging.warning("Check update version error: {}.".format(str(error)))
            return {
                'version': args.get('current_version'),
                'release_date': '',
                'release_notes': '',
                'can_auto_update': False
            }
            result['version'] = args.get('current_version')
            return result

        content = json.loads(response.content)
        return {
            'version': content['version'],
            'release_date': content['releaseDate'],
            'release_notes': content['releaseNotes'],
            'can_auto_update': content['canAutoUpdate']
        }
        result['version'] = content['version']
        result['release_date'] = content['releaseDate']
        result['release_notes'] = content['releaseNotes']
        result['can_auto_update'] = content['canAutoUpdate']
        return result


api.add_resource(VersionApi, '/version')
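A hedged sketch of what the reworked endpoint returns when `CHECK_UPDATE_URL` is unset; the base URL is a deployment-specific assumption:

```python
import requests

resp = requests.get(
    'http://localhost:5001/console/api/version',     # assumed console API prefix
    params={'current_version': '0.6.9'},
)
data = resp.json()
# Instead of a hard-coded '0.0.0', the fallback now reports the running
# version plus the two feature flags introduced above.
print(data['version'])                                   # e.g. '0.6.10'
print(data['features']['model_load_balancing_enabled'])  # mirrors MODEL_LB_ENABLED
```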
api/controllers/console/workspace/load_balancing_config.py (new file): 106 lines
@ -0,0 +1,106 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService


class LoadBalancingCredentialsValidateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider: str):
        if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
            raise Forbidden()

        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
        parser.add_argument('model_type', type=str, required=True, nullable=False,
                            choices=[mt.value for mt in ModelType], location='json')
        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
        args = parser.parse_args()

        # validate model load balancing credentials
        model_load_balancing_service = ModelLoadBalancingService()

        result = True
        error = None

        try:
            model_load_balancing_service.validate_load_balancing_credentials(
                tenant_id=tenant_id,
                provider=provider,
                model=args['model'],
                model_type=args['model_type'],
                credentials=args['credentials']
            )
        except CredentialsValidateFailedError as ex:
            result = False
            error = str(ex)

        response = {'result': 'success' if result else 'error'}

        if not result:
            response['error'] = error

        return response


class LoadBalancingConfigCredentialsValidateApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider: str, config_id: str):
        if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
            raise Forbidden()

        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
        parser.add_argument('model_type', type=str, required=True, nullable=False,
                            choices=[mt.value for mt in ModelType], location='json')
        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
        args = parser.parse_args()

        # validate model load balancing config credentials
        model_load_balancing_service = ModelLoadBalancingService()

        result = True
        error = None

        try:
            model_load_balancing_service.validate_load_balancing_credentials(
                tenant_id=tenant_id,
                provider=provider,
                model=args['model'],
                model_type=args['model_type'],
                credentials=args['credentials'],
                config_id=config_id,
            )
        except CredentialsValidateFailedError as ex:
            result = False
            error = str(ex)

        response = {'result': 'success' if result else 'error'}

        if not result:
            response['error'] = error

        return response


# Load Balancing Config
api.add_resource(LoadBalancingCredentialsValidateApi,
                 '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate')

api.add_resource(LoadBalancingConfigCredentialsValidateApi,
                 '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate')
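A minimal client-side sketch (not part of the compare view) of how the first endpoint registered above is expected to be called; the base URL, the lack of auth handling, and the openai provider/credential key are assumptions for illustration only.

import requests

CONSOLE_API = 'http://localhost:5001/console/api'  # assumed local console API root

# Validate a candidate credential before saving it as a load-balancing config
resp = requests.post(
    f'{CONSOLE_API}/workspaces/current/model-providers/openai/models/'
    'load-balancing-configs/credentials-validate',
    json={
        'model': 'gpt-3.5-turbo',          # illustrative values
        'model_type': 'llm',
        'credentials': {'openai_api_key': 'sk-...'},
    },
)
print(resp.json())  # {'result': 'success'} or {'result': 'error', 'error': '...'}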
@@ -12,6 +12,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.utils.encoders import jsonable_encoder
 from libs.login import login_required
 from models.account import TenantAccountRole
+from services.model_load_balancing_service import ModelLoadBalancingService
 from services.model_provider_service import ModelProviderService
@@ -104,21 +105,56 @@ class ModelProviderModelApi(Resource):
         parser.add_argument('model', type=str, required=True, nullable=False, location='json')
         parser.add_argument('model_type', type=str, required=True, nullable=False,
                             choices=[mt.value for mt in ModelType], location='json')
-        parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json')
+        parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json')
+        parser.add_argument('config_from', type=str, required=False, nullable=True, location='json')
         args = parser.parse_args()
 
-        model_provider_service = ModelProviderService()
+        model_load_balancing_service = ModelLoadBalancingService()
 
-        try:
-            model_provider_service.save_model_credentials(
+        if ('load_balancing' in args and args['load_balancing'] and
+                'enabled' in args['load_balancing'] and args['load_balancing']['enabled']):
+            if 'configs' not in args['load_balancing']:
+                raise ValueError('invalid load balancing configs')
+
+            # save load balancing configs
+            model_load_balancing_service.update_load_balancing_configs(
                 tenant_id=tenant_id,
                 provider=provider,
                 model=args['model'],
                 model_type=args['model_type'],
-                credentials=args['credentials']
+                configs=args['load_balancing']['configs']
             )
-        except CredentialsValidateFailedError as ex:
-            raise ValueError(str(ex))
+
+            # enable load balancing
+            model_load_balancing_service.enable_model_load_balancing(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type']
+            )
+        else:
+            # disable load balancing
+            model_load_balancing_service.disable_model_load_balancing(
+                tenant_id=tenant_id,
+                provider=provider,
+                model=args['model'],
+                model_type=args['model_type']
+            )
+
+            if args.get('config_from', '') != 'predefined-model':
+                model_provider_service = ModelProviderService()
+
+                try:
+                    model_provider_service.save_model_credentials(
+                        tenant_id=tenant_id,
+                        provider=provider,
+                        model=args['model'],
+                        model_type=args['model_type'],
+                        credentials=args['credentials']
+                    )
+                except CredentialsValidateFailedError as ex:
+                    raise ValueError(str(ex))
 
         return {'result': 'success'}, 200
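For reference, a sketch of the JSON body the extended parser above accepts when turning load balancing on; the config entries are invented, and their exact shape is defined by ModelLoadBalancingService.update_load_balancing_configs.

payload = {
    'model': 'gpt-4',
    'model_type': 'llm',
    'config_from': 'predefined-model',
    'load_balancing': {
        'enabled': True,
        'configs': [  # hypothetical entries for illustration
            {'name': 'key-a', 'credentials': {'openai_api_key': 'sk-a...'}},
            {'name': 'key-b', 'credentials': {'openai_api_key': 'sk-b...'}},
        ],
    },
}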
@@ -170,11 +206,73 @@ class ModelProviderModelCredentialApi(Resource):
             model=args['model']
         )
 
+        model_load_balancing_service = ModelLoadBalancingService()
+        is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model'],
+            model_type=args['model_type']
+        )
+
         return {
-            "credentials": credentials
+            "credentials": credentials,
+            "load_balancing": {
+                "enabled": is_load_balancing_enabled,
+                "configs": load_balancing_configs
+            }
         }
 
 
+class ModelProviderModelEnableApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+        model_provider_service.enable_model(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model'],
+            model_type=args['model_type']
+        )
+
+        return {'result': 'success'}
+
+
+class ModelProviderModelDisableApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def patch(self, provider: str):
+        tenant_id = current_user.current_tenant_id
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=[mt.value for mt in ModelType], location='json')
+        args = parser.parse_args()
+
+        model_provider_service = ModelProviderService()
+        model_provider_service.disable_model(
+            tenant_id=tenant_id,
+            provider=provider,
+            model=args['model'],
+            model_type=args['model_type']
+        )
+
+        return {'result': 'success'}
+
+
 class ModelProviderModelValidateApi(Resource):
 
     @setup_required
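A hedged usage sketch for the two PATCH endpoints added above, reusing requests and CONSOLE_API from the earlier sketch; values are illustrative.

# Take a model out of rotation without deleting its credentials...
requests.patch(
    f'{CONSOLE_API}/workspaces/current/model-providers/openai/models/disable',
    json={'model': 'gpt-3.5-turbo', 'model_type': 'llm'},
)

# ...and bring it back later via the sibling route.
requests.patch(
    f'{CONSOLE_API}/workspaces/current/model-providers/openai/models/enable',
    json={'model': 'gpt-3.5-turbo', 'model_type': 'llm'},
)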
@@ -259,6 +357,10 @@ class ModelProviderAvailableModelApi(Resource):
 
 
 api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
+api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable',
+                 endpoint='model-provider-model-enable')
+api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable',
+                 endpoint='model-provider-model-disable')
 api.add_resource(ModelProviderModelCredentialApi,
                  '/workspaces/current/model-providers/<string:provider>/models/credentials')
 api.add_resource(ModelProviderModelValidateApi,
@@ -1,5 +1,6 @@
 from flask import request
 from flask_restful import marshal, reqparse
+from werkzeug.exceptions import NotFound
 
 import services.dataset_service
 from controllers.service_api import api
@@ -19,10 +20,12 @@ def _validate_name(name):
     return name
 
 
-class DatasetApi(DatasetApiResource):
-    """Resource for get datasets."""
+class DatasetListApi(DatasetApiResource):
+    """Resource for datasets."""
 
     def get(self, tenant_id):
+        """Resource for getting datasets."""
+
         page = request.args.get('page', default=1, type=int)
         limit = request.args.get('limit', default=20, type=int)
         provider = request.args.get('provider', default="vendor")
@@ -65,9 +68,9 @@ class DatasetApi(DatasetApiResource):
         }
         return response, 200
 
-    """Resource for datasets."""
-
     def post(self, tenant_id):
+        """Resource for creating datasets."""
         parser = reqparse.RequestParser()
         parser.add_argument('name', nullable=False, required=True,
                             help='type is required. Name must be between 1 to 40 characters.',
@@ -89,6 +92,31 @@ class DatasetApi(DatasetApiResource):
 
         return marshal(dataset, dataset_detail_fields), 200
 
+
+class DatasetApi(DatasetApiResource):
+    """Resource for dataset."""
+
-api.add_resource(DatasetApi, '/datasets')
+    def delete(self, _, dataset_id):
+        """
+        Deletes a dataset given its ID.
+
+        Args:
+            dataset_id (UUID): The ID of the dataset to be deleted.
+
+        Returns:
+            dict: A dictionary with a key 'result' and a value 'success'
+                  if the dataset was successfully deleted. Omitted in HTTP response.
+            int: HTTP status code 204 indicating that the operation was successful.
+
+        Raises:
+            NotFound: If the dataset with the given ID does not exist.
+        """
+
+        dataset_id_str = str(dataset_id)
+
+        if DatasetService.delete_dataset(dataset_id_str, current_user):
+            return {'result': 'success'}, 204
+        else:
+            raise NotFound("Dataset not found.")
+
+
+api.add_resource(DatasetListApi, '/datasets')
+api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
@@ -8,7 +8,7 @@ from flask import current_app, request
 from flask_login import user_logged_in
 from flask_restful import Resource
 from pydantic import BaseModel
-from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
+from werkzeug.exceptions import Forbidden, Unauthorized
 
 from extensions.ext_database import db
 from libs.login import _get_user
@@ -39,17 +39,17 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
 
             app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
             if not app_model:
-                raise NotFound()
+                raise Forbidden("The app no longer exists.")
 
             if app_model.status != 'normal':
-                raise NotFound()
+                raise Forbidden("The app's status is abnormal.")
 
             if not app_model.enable_api:
-                raise NotFound()
+                raise Forbidden("The app's API service has been disabled.")
 
             tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
             if tenant.status == TenantStatus.ARCHIVE:
-                raise NotFound()
+                raise Forbidden("The workspace's status is archived.")
 
             kwargs['app_model'] = app_model
@@ -74,7 +74,7 @@ class TextApi(WebApiResource):
             app_model=app_model,
             text=request.form['text'],
             end_user=end_user.external_user_id,
-            voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
+            voice=request.form['voice'] if request.form.get('voice') else None,
             streaming=False
         )
@@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import (
 from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
 from core.tools.tool.tool import Tool
 from core.tools.tool_manager import ToolManager
+from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from extensions.ext_database import db
 from models.model import Conversation, Message, MessageAgentThought
 from models.tools import ToolConversationVariables
@@ -128,6 +129,8 @@ class BaseAgentRunner(AppRunner):
             self.files = application_generate_entity.files
         else:
             self.files = []
+        self.query = None
+        self._current_thoughts: list[PromptMessage] = []
 
     def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
             -> AgentChatAppGenerateEntity:
@@ -184,21 +187,11 @@ class BaseAgentRunner(AppRunner):
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
                 continue
 
-            parameter_type = 'string'
+            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
             enum = []
-            if parameter.type == ToolParameter.ToolParameterType.STRING:
-                parameter_type = 'string'
-            elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
-                parameter_type = 'boolean'
-            elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
-                parameter_type = 'number'
-            elif parameter.type == ToolParameter.ToolParameterType.SELECT:
-                for option in parameter.options:
-                    enum.append(option.value)
-                parameter_type = 'string'
-            else:
-                raise ValueError(f"parameter type {parameter.type} is not supported")
+            if parameter.type == ToolParameter.ToolParameterType.SELECT:
+                enum = [option.value for option in parameter.options]
 
             message_tool.parameters['properties'][parameter.name] = {
                 "type": parameter_type,
                 "description": parameter.llm_description or '',
@@ -279,20 +272,10 @@ class BaseAgentRunner(AppRunner):
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
                 continue
 
-            parameter_type = 'string'
+            parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
             enum = []
-            if parameter.type == ToolParameter.ToolParameterType.STRING:
-                parameter_type = 'string'
-            elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
-                parameter_type = 'boolean'
-            elif parameter.type == ToolParameter.ToolParameterType.NUMBER:
-                parameter_type = 'number'
-            elif parameter.type == ToolParameter.ToolParameterType.SELECT:
-                for option in parameter.options:
-                    enum.append(option.value)
-                parameter_type = 'string'
-            else:
-                raise ValueError(f"parameter type {parameter.type} is not supported")
+            if parameter.type == ToolParameter.ToolParameterType.SELECT:
+                enum = [option.value for option in parameter.options]
 
             prompt_tool.parameters['properties'][parameter.name] = {
                 "type": parameter_type,
@@ -464,7 +447,7 @@ class BaseAgentRunner(AppRunner):
         for message in messages:
             if message.id == self.message.id:
                 continue
 
             result.append(self.organize_agent_user_prompt(message))
             agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
             if agent_thoughts:
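The deleted if/elif chains above now live behind ToolParameterConverter. A minimal sketch of what its get_parameter_type is assumed to do, inferred from the removed branches; the real implementation is in core/tools/utils/tool_parameter_converter.py and may differ.

from core.tools.entities.tool_entities import ToolParameter

class ToolParameterConverter:  # sketch of the assumed shape, not the real class
    @staticmethod
    def get_parameter_type(parameter_type: 'ToolParameter.ToolParameterType') -> str:
        # SELECT collapses to 'string'; its options are surfaced via the enum list
        mapping = {
            ToolParameter.ToolParameterType.STRING: 'string',
            ToolParameter.ToolParameterType.SELECT: 'string',
            ToolParameter.ToolParameterType.BOOLEAN: 'boolean',
            ToolParameter.ToolParameterType.NUMBER: 'number',
        }
        if parameter_type not in mapping:
            raise ValueError(f"parameter type {parameter_type} is not supported")
        return mapping[parameter_type]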
@@ -15,6 +15,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool.tool import Tool
 from core.tools.tool_engine import ToolEngine
@@ -373,7 +374,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
 
         return message
 
-    def _organize_historic_prompt_messages(self) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
        organize historic prompt messages
         """
@@ -381,6 +382,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         scratchpad: list[AgentScratchpadUnit] = []
         current_scratchpad: AgentScratchpadUnit = None
 
+        self.history_prompt_messages = AgentHistoryPromptTransform(
+            model_config=self.model_config,
+            prompt_messages=current_session_messages or [],
+            history_messages=self.history_prompt_messages,
+            memory=self.memory
+        ).get_prompt()
+
         for message in self.history_prompt_messages:
             if isinstance(message, AssistantPromptMessage):
                 current_scratchpad = AgentScratchpadUnit(
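For orientation, the assumed contract of the transform introduced above, stated as an illustrative comment rather than the real implementation:

# AgentHistoryPromptTransform is assumed to take the messages that must stay
# (the current session) plus the full history, and return a history slice that
# still fits the model's context budget:
#
#   trimmed = AgentHistoryPromptTransform(
#       model_config=..., prompt_messages=current_session,
#       history_messages=full_history, memory=memory,
#   ).get_prompt()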
@@ -32,9 +32,6 @@ class CotChatAgentRunner(CotAgentRunner):
         # organize system prompt
         system_message = self._organize_system_prompt()
 
-        # organize historic prompt messages
-        historic_messages = self._historic_prompt_messages
-
         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
         if not agent_scratchpad:
@@ -57,6 +54,13 @@ class CotChatAgentRunner(CotAgentRunner):
         query_messages = UserPromptMessage(content=self._query)
 
         if assistant_messages:
+            # organize historic prompt messages
+            historic_messages = self._organize_historic_prompt_messages([
+                system_message,
+                query_messages,
+                *assistant_messages,
+                UserPromptMessage(content='continue')
+            ])
             messages = [
                 system_message,
                 *historic_messages,
@@ -65,6 +69,8 @@ class CotChatAgentRunner(CotAgentRunner):
                 UserPromptMessage(content='continue')
             ]
         else:
+            # organize historic prompt messages
+            historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
             messages = [system_message, *historic_messages, query_messages]
 
         # join all messages
@@ -19,11 +19,11 @@ class CotCompletionAgentRunner(CotAgentRunner):
 
         return system_prompt
 
-    def _organize_historic_prompt(self) -> str:
+    def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
         """
         Organize historic prompt
         """
-        historic_prompt_messages = self._historic_prompt_messages
+        historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
         historic_prompt = ""
 
         for message in historic_prompt_messages:
@@ -17,6 +17,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from core.tools.entities.tool_entities import ToolInvokeMeta
 from core.tools.tool_engine import ToolEngine
 from models.model import Message
@@ -24,21 +25,18 @@ from models.model import Message
 logger = logging.getLogger(__name__)
 
 
 class FunctionCallAgentRunner(BaseAgentRunner):
 
     def run(self,
             message: Message, query: str, **kwargs: Any
             ) -> Generator[LLMResultChunk, None, None]:
         """
         Run FunctionCall agent application
         """
+        self.query = query
         app_generate_entity = self.application_generate_entity
 
         app_config = self.app_config
 
-        prompt_template = app_config.prompt_template.simple_prompt_template or ''
-        prompt_messages = self.history_prompt_messages
-        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
-        prompt_messages = self._organize_user_query(query, prompt_messages)
-
         # convert tools into ModelRuntime Tool format
         tool_instances, prompt_messages_tools = self._init_prompt_tools()
@@ -81,6 +79,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             )
 
             # recalc llm max tokens
+            prompt_messages = self._organize_prompt_messages()
             self.recalc_llm_max_tokens(self.model_config, prompt_messages)
             # invoke model
             chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
@@ -203,7 +202,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
             else:
                 assistant_message.content = response
 
-            prompt_messages.append(assistant_message)
+            self._current_thoughts.append(assistant_message)
 
             # save thought
             self.save_agent_thought(
@@ -265,12 +264,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                 }
 
                 tool_responses.append(tool_response)
-                prompt_messages = self._organize_assistant_message(
-                    tool_call_id=tool_call_id,
-                    tool_call_name=tool_call_name,
-                    tool_response=tool_response['tool_response'],
-                    prompt_messages=prompt_messages,
-                )
+                if tool_response['tool_response'] is not None:
+                    self._current_thoughts.append(
+                        ToolPromptMessage(
+                            content=tool_response['tool_response'],
+                            tool_call_id=tool_call_id,
+                            name=tool_call_name,
+                        )
+                    )
 
             if len(tool_responses) > 0:
                 # save agent thought
@@ -300,8 +301,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
             iteration_step += 1
 
-            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
-
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
         self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
@@ -393,24 +392,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
 
         return prompt_messages
 
-    def _organize_assistant_message(self, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None,
-                                    prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
-        """
-        Organize assistant message
-        """
-        prompt_messages = deepcopy(prompt_messages)
-
-        if tool_response is not None:
-            prompt_messages.append(
-                ToolPromptMessage(
-                    content=tool_response,
-                    tool_call_id=tool_call_id,
-                    name=tool_call_name,
-                )
-            )
-
-        return prompt_messages
-
     def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         As for now, gpt supports both fc and vision at the first iteration.
@@ -428,4 +409,26 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                     for content in prompt_message.content
                 ])
 
-        return prompt_messages
+        return prompt_messages
+
+    def _organize_prompt_messages(self):
+        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
+        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
+        query_prompt_messages = self._organize_user_query(self.query, [])
+
+        self.history_prompt_messages = AgentHistoryPromptTransform(
+            model_config=self.model_config,
+            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
+            history_messages=self.history_prompt_messages,
+            memory=self.memory
+        ).get_prompt()
+
+        prompt_messages = [
+            *self.history_prompt_messages,
+            *query_prompt_messages,
+            *self._current_thoughts
+        ]
+        if len(self._current_thoughts) != 0:
+            # clear messages after the first iteration
+            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
+        return prompt_messages
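A brief note on _organize_prompt_messages above, stated as an assumed-intent comment:

# Per-iteration prompt layout this method appears to produce:
#   [system message + trimmed history] + [current user query] + [this run's
#   assistant tool_calls and tool results (_current_thoughts)]
# Rebuilding the list each iteration replaces the old pattern of mutating a
# single prompt_messages list across the run loop.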
@@ -107,8 +107,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
             usage=LLMUsage.empty_usage()
         )
 
-        self._stream_generate_routes = self._get_stream_generate_routes()
         self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
+        self._stream_generate_routes = self._get_stream_generate_routes()
         self._conversation_name_generate_thread = None
 
     def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@@ -410,6 +410,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 ingoing_edges.append(edge)
 
         if not ingoing_edges:
+            # check if it's the first node in the iteration
+            target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
+            if not target_node:
+                return []
+
+            node_iteration_id = target_node.get('data', {}).get('iteration_id')
+            # get iteration start node id
+            for node in nodes:
+                if node.get('id') == node_iteration_id:
+                    if node.get('data', {}).get('start_node_id') == target_node_id:
+                        return [target_node_id]
+
             return []
 
         start_node_ids = []
@@ -514,6 +526,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 self._task_state.answer += route_chunk.text
                 yield self._message_to_stream_response(route_chunk.text, self._message.id)
             else:
+                value = None
                 route_chunk = cast(VarGenerateRouteChunk, route_chunk)
                 value_selector = route_chunk.value_selector
                 if not value_selector:
@@ -525,6 +538,20 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 if route_chunk_node_id == 'sys':
                     # system variable
                     value = self._workflow_system_variables.get(SystemVariable.value_of(value_selector[1]))
+                elif route_chunk_node_id in self._iteration_nested_relations:
+                    # it's a iteration variable
+                    if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
+                        continue
+                    iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
+                    iterator = iteration_state.inputs
+                    if not iterator:
+                        continue
+                    iterator_selector = iterator.get('iterator_selector', [])
+                    if value_selector[1] == 'index':
+                        value = iteration_state.current_index
+                    elif value_selector[1] == 'item':
+                        value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
+                            iterator_selector) else None
                 else:
                     # check chunk node id is before current node id or equal to current node id
                     if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
@@ -554,7 +581,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                         else:
                             value = value.get(key)
 
-                if value:
+                if value is not None:
                     text = ''
                     if isinstance(value, str | int | float):
                         text = str(value)
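Illustrative only: how the iteration branch added above resolves a stream route (node and selector names invented):

# value_selector == ['iter_node_1', 'index'] -> iteration_state.current_index
# value_selector == ['iter_node_1', 'item']  -> iterator_selector[current_index]
#                                               (None once the index runs past the list)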
@@ -1,6 +1,6 @@
 import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Optional, Union
 
 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -16,11 +16,11 @@ from core.app.features.hosting_moderation.hosting_moderation import HostingModer
 from core.external_data_tool.external_data_fetch import ExternalDataFetch
 from core.file.file_obj import FileVar
 from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
 from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.errors.invoke import InvokeBadRequestError
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.input_moderation import InputModeration
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -45,8 +45,11 @@ class AppRunner:
         :param query: query
         :return:
         """
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        # Invoke model
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle,
+            model=model_config.model
+        )
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
 
@@ -73,9 +76,7 @@ class AppRunner:
             query=query
         )
 
-        prompt_tokens = model_type_instance.get_num_tokens(
-            model_config.model,
-            model_config.credentials,
+        prompt_tokens = model_instance.get_llm_num_tokens(
             prompt_messages
         )
 
@@ -89,8 +90,10 @@ class AppRunner:
     def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
                               prompt_messages: list[PromptMessage]):
         # recalc max_tokens if sum(prompt_token +  max_tokens) over model token limit
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle,
+            model=model_config.model
+        )
 
         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
 
@@ -107,9 +110,7 @@ class AppRunner:
         if max_tokens is None:
             max_tokens = 0
 
-        prompt_tokens = model_type_instance.get_num_tokens(
-            model_config.model,
-            model_config.credentials,
+        prompt_tokens = model_instance.get_llm_num_tokens(
            prompt_messages
         )
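A hedged sketch of the ModelInstance.get_llm_num_tokens wrapper this refactor leans on: it is assumed to forward to the underlying LLM's get_num_tokens using the instance's own model and credentials, so callers stop threading those through. The real method lives in core/model_manager.py and may differ (for example, it likely routes through the new load-balancing logic).

from typing import cast
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

class ModelInstance:  # excerpt-style sketch of the assumed shape, not the real class
    def get_llm_num_tokens(self, prompt_messages, tools=None) -> int:
        model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
        return model_type_instance.get_num_tokens(
            self.model, self.credentials, prompt_messages, tools
        )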
@@ -37,6 +37,7 @@ from core.app.entities.task_entities import (
 )
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -317,29 +318,30 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         """
         model_config = self._model_config
         model = model_config.model
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
+
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle,
+            model=model_config.model
+        )
 
         # calculate num tokens
         prompt_tokens = 0
         if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
-            prompt_tokens = model_type_instance.get_num_tokens(
-                model,
-                model_config.credentials,
+            prompt_tokens = model_instance.get_llm_num_tokens(
                 self._task_state.llm_result.prompt_messages
             )
 
         completion_tokens = 0
         if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
-            completion_tokens = model_type_instance.get_num_tokens(
-                model,
-                model_config.credentials,
+            completion_tokens = model_instance.get_llm_num_tokens(
                 [self._task_state.llm_result.message]
             )
 
         credentials = model_config.credentials
 
         # transform usage
+        model_type_instance = model_config.provider_model_bundle.model_type_instance
+        model_type_instance = cast(LargeLanguageModel, model_type_instance)
         self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
             model,
             credentials,
@@ -16,6 +16,7 @@ class ModelStatus(Enum):
     NO_CONFIGURE = "no-configure"
     QUOTA_EXCEEDED = "quota-exceeded"
     NO_PERMISSION = "no-permission"
+    DISABLED = "disabled"
 
 
 class SimpleModelProviderEntity(BaseModel):
@@ -43,12 +44,19 @@ class SimpleModelProviderEntity(BaseModel):
         )
 
 
-class ModelWithProviderEntity(ProviderModel):
+class ProviderModelWithStatusEntity(ProviderModel):
     """
     Model class for model response.
     """
+    status: ModelStatus
+    load_balancing_enabled: bool = False
+
+
+class ModelWithProviderEntity(ProviderModelWithStatusEntity):
+    """
+    Model with provider entity.
+    """
     provider: SimpleModelProviderEntity
-    status: ModelStatus
 
 
 class DefaultModelProviderEntity(BaseModel):
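A small sketch of what the split above enables: code that only cares about per-model status can now type against the new base entity without needing a provider attached (the helper function is invented for illustration):

from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity

def is_selectable(m: ProviderModelWithStatusEntity) -> bool:
    # DISABLED joins the existing inactive statuses introduced by this change
    return m.status == ModelStatus.ACTIVE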
@@ -1,6 +1,7 @@
 import datetime
 import json
 import logging
+from collections import defaultdict
 from collections.abc import Iterator
 from json import JSONDecodeError
 from typing import Optional
@@ -8,7 +9,12 @@ from typing import Optional
 from pydantic import BaseModel
 
 from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
-from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus
+from core.entities.provider_entities import (
+    CustomConfiguration,
+    ModelSettings,
+    SystemConfiguration,
+    SystemConfigurationStatus,
+)
 from core.helper import encrypter
 from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
 from core.model_runtime.entities.model_entities import FetchFrom, ModelType
@@ -22,7 +28,14 @@ from core.model_runtime.model_providers import model_provider_factory
 from core.model_runtime.model_providers.__base.ai_model import AIModel
 from core.model_runtime.model_providers.__base.model_provider import ModelProvider
 from extensions.ext_database import db
-from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider
+from models.provider import (
+    LoadBalancingModelConfig,
+    Provider,
+    ProviderModel,
+    ProviderModelSetting,
+    ProviderType,
+    TenantPreferredModelProvider,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -39,6 +52,7 @@ class ProviderConfiguration(BaseModel):
     using_provider_type: ProviderType
     system_configuration: SystemConfiguration
     custom_configuration: CustomConfiguration
+    model_settings: list[ModelSettings]
 
     def __init__(self, **data):
         super().__init__(**data)
@@ -62,6 +76,14 @@ class ProviderConfiguration(BaseModel):
         :param model: model name
         :return:
         """
+        if self.model_settings:
+            # check if model is disabled by admin
+            for model_setting in self.model_settings:
+                if (model_setting.model_type == model_type
+                        and model_setting.model == model):
+                    if not model_setting.enabled:
+                        raise ValueError(f'Model {model} is disabled.')
+
         if self.using_provider_type == ProviderType.SYSTEM:
             restrict_models = []
             for quota_configuration in self.system_configuration.quota_configurations:
@@ -80,15 +102,17 @@ class ProviderConfiguration(BaseModel):
 
                 return copy_credentials
         else:
+            credentials = None
             if self.custom_configuration.models:
                 for model_configuration in self.custom_configuration.models:
                     if model_configuration.model_type == model_type and model_configuration.model == model:
-                        return model_configuration.credentials
+                        credentials = model_configuration.credentials
+                        break
 
             if self.custom_configuration.provider:
-                return self.custom_configuration.provider.credentials
-            else:
-                return None
+                credentials = self.custom_configuration.provider.credentials
 
+            return credentials
 
     def get_system_configuration_status(self) -> SystemConfigurationStatus:
         """
@@ -130,7 +154,7 @@ class ProviderConfiguration(BaseModel):
             return credentials
 
         # Obfuscate credentials
-        return self._obfuscated_credentials(
+        return self.obfuscated_credentials(
             credentials=credentials,
             credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
             if self.provider.provider_credential_schema else []
@@ -151,7 +175,7 @@ class ProviderConfiguration(BaseModel):
         ).first()
 
         # Get provider credential secret variables
-        provider_credential_secret_variables = self._extract_secret_variables(
+        provider_credential_secret_variables = self.extract_secret_variables(
             self.provider.provider_credential_schema.credential_form_schemas
             if self.provider.provider_credential_schema else []
         )
@@ -274,7 +298,7 @@ class ProviderConfiguration(BaseModel):
             return credentials
 
         # Obfuscate credentials
-        return self._obfuscated_credentials(
+        return self.obfuscated_credentials(
             credentials=credentials,
             credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
             if self.provider.model_credential_schema else []
@@ -302,7 +326,7 @@ class ProviderConfiguration(BaseModel):
         ).first()
 
         # Get provider credential secret variables
-        provider_credential_secret_variables = self._extract_secret_variables(
+        provider_credential_secret_variables = self.extract_secret_variables(
             self.provider.model_credential_schema.credential_form_schemas
             if self.provider.model_credential_schema else []
         )
@@ -402,6 +426,160 @@ class ProviderConfiguration(BaseModel):
 
         provider_model_credentials_cache.delete()
 
+    def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+        """
+        Enable model.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        model_setting = db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == self.tenant_id,
+            ProviderModelSetting.provider_name == self.provider.provider,
+            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+            ProviderModelSetting.model_name == model
+        ).first()
+
+        if model_setting:
+            model_setting.enabled = True
+            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            db.session.commit()
+        else:
+            model_setting = ProviderModelSetting(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                model_type=model_type.to_origin_model_type(),
+                model_name=model,
+                enabled=True
+            )
+            db.session.add(model_setting)
+            db.session.commit()
+
+        return model_setting
+
+    def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+        """
+        Disable model.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        model_setting = db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == self.tenant_id,
+            ProviderModelSetting.provider_name == self.provider.provider,
+            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+            ProviderModelSetting.model_name == model
+        ).first()
+
+        if model_setting:
+            model_setting.enabled = False
+            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            db.session.commit()
+        else:
+            model_setting = ProviderModelSetting(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                model_type=model_type.to_origin_model_type(),
+                model_name=model,
+                enabled=False
+            )
+            db.session.add(model_setting)
+            db.session.commit()
+
+        return model_setting
+
+    def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
+        """
+        Get provider model setting.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        return db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == self.tenant_id,
+            ProviderModelSetting.provider_name == self.provider.provider,
+            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+            ProviderModelSetting.model_name == model
+        ).first()
+
+    def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+        """
+        Enable model load balancing.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        load_balancing_config_count = db.session.query(LoadBalancingModelConfig) \
+            .filter(
+            LoadBalancingModelConfig.tenant_id == self.tenant_id,
+            LoadBalancingModelConfig.provider_name == self.provider.provider,
+            LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
+            LoadBalancingModelConfig.model_name == model
+        ).count()
+
+        if load_balancing_config_count <= 1:
+            raise ValueError('Model load balancing configuration must be more than 1.')
+
+        model_setting = db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == self.tenant_id,
+            ProviderModelSetting.provider_name == self.provider.provider,
+            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+            ProviderModelSetting.model_name == model
+        ).first()
+
+        if model_setting:
+            model_setting.load_balancing_enabled = True
+            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            db.session.commit()
+        else:
+            model_setting = ProviderModelSetting(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                model_type=model_type.to_origin_model_type(),
+                model_name=model,
+                load_balancing_enabled=True
+            )
+            db.session.add(model_setting)
+            db.session.commit()
+
+        return model_setting
+
+    def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
+        """
+        Disable model load balancing.
+        :param model_type: model type
+        :param model: model name
+        :return:
+        """
+        model_setting = db.session.query(ProviderModelSetting) \
+            .filter(
+            ProviderModelSetting.tenant_id == self.tenant_id,
+            ProviderModelSetting.provider_name == self.provider.provider,
+            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
+            ProviderModelSetting.model_name == model
+        ).first()
+
+        if model_setting:
+            model_setting.load_balancing_enabled = False
+            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            db.session.commit()
+        else:
+            model_setting = ProviderModelSetting(
+                tenant_id=self.tenant_id,
+                provider_name=self.provider.provider,
+                model_type=model_type.to_origin_model_type(),
+                model_name=model,
+                load_balancing_enabled=False
+            )
+            db.session.add(model_setting)
+            db.session.commit()
+
+        return model_setting
+
     def get_provider_instance(self) -> ModelProvider:
         """
         Get provider instance.
@@ -453,7 +631,7 @@ class ProviderConfiguration(BaseModel):
 
         db.session.commit()
 
-    def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
+    def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
         """
         Extract secret input form variables.
 
@@ -467,7 +645,7 @@ class ProviderConfiguration(BaseModel):
 
         return secret_input_form_variables
 
-    def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
+    def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
         """
         Obfuscated credentials.
 
@@ -476,7 +654,7 @@ class ProviderConfiguration(BaseModel):
         :return:
         """
         # Get provider credential secret variables
-        credential_secret_variables = self._extract_secret_variables(
+        credential_secret_variables = self.extract_secret_variables(
             credential_form_schemas
         )
 
@@ -522,15 +700,22 @@ class ProviderConfiguration(BaseModel):
         else:
             model_types = provider_instance.get_provider_schema().supported_model_types
 
+        # Group model settings by model type and model
+        model_setting_map = defaultdict(dict)
+        for model_setting in self.model_settings:
+            model_setting_map[model_setting.model_type][model_setting.model] = model_setting
+
         if self.using_provider_type == ProviderType.SYSTEM:
             provider_models = self._get_system_provider_models(
                 model_types=model_types,
-                provider_instance=provider_instance
+                provider_instance=provider_instance,
+                model_setting_map=model_setting_map
            )
         else:
            provider_models = self._get_custom_provider_models(
                 model_types=model_types,
-                provider_instance=provider_instance
+                provider_instance=provider_instance,
+                model_setting_map=model_setting_map
             )
 
         if only_active:
@@ -541,18 +726,27 @@ class ProviderConfiguration(BaseModel):
 
     def _get_system_provider_models(self,
                                     model_types: list[ModelType],
-                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+                                    provider_instance: ModelProvider,
+                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
+            -> list[ModelWithProviderEntity]:
         """
         Get system provider models.
 
         :param model_types: model types
         :param provider_instance: provider instance
+        :param model_setting_map: model setting map
        :return:
         """
         provider_models = []
         for model_type in model_types:
-            provider_models.extend(
-                [
+            for m in provider_instance.models(model_type):
+                status = ModelStatus.ACTIVE
+                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
+                    model_setting = model_setting_map[m.model_type][m.model]
+                    if model_setting.enabled is False:
+                        status = ModelStatus.DISABLED
+
+                provider_models.append(
                     ModelWithProviderEntity(
                         model=m.model,
                         label=m.label,
@@ -562,11 +756,9 @@ class ProviderConfiguration(BaseModel):
                         model_properties=m.model_properties,
                         deprecated=m.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
-                        status=ModelStatus.ACTIVE
+                        status=status
                     )
-                    for m in provider_instance.models(model_type)
-                ]
-            )
+                )
 
         if self.provider.provider not in original_provider_configurate_methods:
             original_provider_configurate_methods[self.provider.provider] = []
@@ -586,7 +778,8 @@ class ProviderConfiguration(BaseModel):
                     break
 
             if should_use_custom_model:
-                if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
+                if original_provider_configurate_methods[self.provider.provider] == [
+                    ConfigurateMethod.CUSTOMIZABLE_MODEL]:
                     # only customizable model
                     for restrict_model in restrict_models:
                         copy_credentials = self.system_configuration.credentials.copy()
@@ -611,6 +804,13 @@ class ProviderConfiguration(BaseModel):
                         if custom_model_schema.model_type not in model_types:
                             continue
 
+                        status = ModelStatus.ACTIVE
+                        if (custom_model_schema.model_type in model_setting_map
+                                and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
+                            model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
+                            if model_setting.enabled is False:
+                                status = ModelStatus.DISABLED
+
                         provider_models.append(
                             ModelWithProviderEntity(
                                 model=custom_model_schema.model,
@@ -621,7 +821,7 @@ class ProviderConfiguration(BaseModel):
                                 model_properties=custom_model_schema.model_properties,
                                 deprecated=custom_model_schema.deprecated,
                                 provider=SimpleModelProviderEntity(self.provider),
-                                status=ModelStatus.ACTIVE
+                                status=status
                             )
                         )
 
@@ -632,16 +832,20 @@ class ProviderConfiguration(BaseModel):
                     m.status = ModelStatus.NO_PERMISSION
                 elif not quota_configuration.is_valid:
                     m.status = ModelStatus.QUOTA_EXCEEDED
 
         return provider_models
 
     def _get_custom_provider_models(self,
                                     model_types: list[ModelType],
-                                    provider_instance: ModelProvider) -> list[ModelWithProviderEntity]:
+                                    provider_instance: ModelProvider,
+                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
+            -> list[ModelWithProviderEntity]:
         """
         Get custom provider models.
 
         :param model_types: model types
         :param provider_instance: provider instance
+        :param model_setting_map: model setting map
         :return:
         """
         provider_models = []
@@ -656,6 +860,16 @@ class ProviderConfiguration(BaseModel):
 
             models = provider_instance.models(model_type)
             for m in models:
+                status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
+                load_balancing_enabled = False
+                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
+                    model_setting = model_setting_map[m.model_type][m.model]
+                    if model_setting.enabled is False:
+                        status = ModelStatus.DISABLED
+
+                    if len(model_setting.load_balancing_configs) > 1:
+                        load_balancing_enabled = True
+
                 provider_models.append(
                     ModelWithProviderEntity(
                         model=m.model,
@@ -666,7 +880,8 @@ class ProviderConfiguration(BaseModel):
                         model_properties=m.model_properties,
                         deprecated=m.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
-                        status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
+                        status=status,
+                        load_balancing_enabled=load_balancing_enabled
                     )
                 )
 
@@ -690,6 +905,17 @@ class ProviderConfiguration(BaseModel):
                 if not custom_model_schema:
                     continue
 
+                status = ModelStatus.ACTIVE
+                load_balancing_enabled = False
+                if (custom_model_schema.model_type in model_setting_map
+                        and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]):
+                    model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
+                    if model_setting.enabled is False:
+                        status = ModelStatus.DISABLED
+
+                    if len(model_setting.load_balancing_configs) > 1:
+                        load_balancing_enabled = True
+
                 provider_models.append(
                     ModelWithProviderEntity(
                         model=custom_model_schema.model,
@@ -700,7 +926,8 @@ class ProviderConfiguration(BaseModel):
                         model_properties=custom_model_schema.model_properties,
                         deprecated=custom_model_schema.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
-                        status=ModelStatus.ACTIVE
+                        status=status,
+                        load_balancing_enabled=load_balancing_enabled
                     )
                )
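Note the guard in enable_model_load_balancing above: it refuses to enable balancing unless more than one LoadBalancingModelConfig row exists for the model. A hedged usage sketch:

from core.model_runtime.entities.model_entities import ModelType

# provider_configuration: a ProviderConfiguration for the current tenant/provider
try:
    provider_configuration.enable_model_load_balancing(model_type=ModelType.LLM, model='gpt-4')
except ValueError:
    # fewer than two load-balancing configs have been saved for this model
    pass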
@@ -72,3 +72,22 @@ class CustomConfiguration(BaseModel):
     """
     provider: Optional[CustomProviderConfiguration] = None
     models: list[CustomModelConfiguration] = []
+
+
+class ModelLoadBalancingConfiguration(BaseModel):
+    """
+    Class for model load balancing configuration.
+    """
+    id: str
+    name: str
+    credentials: dict
+
+
+class ModelSettings(BaseModel):
+    """
+    Model class for model settings.
+    """
+    model: str
+    model_type: ModelType
+    enabled: bool = True
+    load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
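Illustrative construction of the new entities (IDs and credential keys invented):

from core.entities.provider_entities import ModelLoadBalancingConfiguration, ModelSettings
from core.model_runtime.entities.model_entities import ModelType

settings = ModelSettings(
    model='gpt-4',
    model_type=ModelType.LLM,
    enabled=True,
    load_balancing_configs=[
        ModelLoadBalancingConfiguration(id='cfg-1', name='key-a', credentials={'openai_api_key': 'sk-a...'}),
        ModelLoadBalancingConfiguration(id='cfg-2', name='key-b', credentials={'openai_api_key': 'sk-b...'}),
    ],
)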
@@ -7,7 +7,7 @@ from typing import Any, Optional
 
 from pydantic import BaseModel
 
-from core.utils.position_helper import sort_to_dict_by_position_map
+from core.helper.position_helper import sort_to_dict_by_position_map
 
 
 class ExtensionModule(enum.Enum):
@@ -9,6 +9,7 @@ from extensions.ext_redis import redis_client
 class ProviderCredentialsCacheType(Enum):
     PROVIDER = "provider"
     MODEL = "provider_model"
+    LOAD_BALANCING_MODEL = "load_balancing_provider_model"
 
 
 class ProviderCredentialsCache:
@ -12,7 +12,6 @@ from flask import Flask, current_app
|
||||
from flask_login import current_user
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from core.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
@ -20,12 +19,16 @@ from core.model_runtime.entities.model_entities import ModelType, PriceType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from core.rag.models.document import Document
|
||||
from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
|
||||
from core.splitter.text_splitter import TextSplitter
|
||||
from core.rag.splitter.fixed_text_splitter import (
|
||||
EnhanceRecursiveCharacterTextSplitter,
|
||||
FixedRecursiveCharacterTextSplitter,
|
||||
)
|
||||
from core.rag.splitter.text_splitter import TextSplitter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
@ -283,11 +286,7 @@ class IndexingRunner:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
if indexing_technique == 'high_quality' or embedding_model_instance:
|
||||
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||
tokens += embedding_model_type_instance.get_num_tokens(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
tokens += embedding_model_instance.get_text_embedding_num_tokens(
|
||||
texts=[self.filter_string(document.page_content)]
|
||||
)
|
||||
|
||||
@ -655,10 +654,6 @@ class IndexingRunner:
|
||||
tokens = 0
|
||||
chunk_size = 10
|
||||
|
||||
embedding_model_type_instance = None
|
||||
if embedding_model_instance:
|
||||
embedding_model_type_instance = embedding_model_instance.model_type_instance
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
|
||||
# create keyword index
|
||||
create_keyword_thread = threading.Thread(target=self._process_keyword_index,
|
||||
args=(current_app._get_current_object(),
|
||||
@ -671,8 +666,7 @@ class IndexingRunner:
|
||||
chunk_documents = documents[i:i + chunk_size]
|
||||
futures.append(executor.submit(self._process_chunk, current_app._get_current_object(), index_processor,
|
||||
chunk_documents, dataset,
|
||||
dataset_document, embedding_model_instance,
|
||||
embedding_model_type_instance))
|
||||
dataset_document, embedding_model_instance))
|
||||
|
||||
for future in futures:
|
||||
tokens += future.result()
|
||||
@ -713,7 +707,7 @@ class IndexingRunner:
|
||||
db.session.commit()
|
||||
|
||||
def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
|
||||
embedding_model_instance, embedding_model_type_instance):
|
||||
embedding_model_instance):
|
||||
with flask_app.app_context():
|
||||
# check document is paused
|
||||
self._check_document_paused_status(dataset_document.id)
|
||||
@ -721,9 +715,7 @@ class IndexingRunner:
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
|
||||
tokens += sum(
|
||||
embedding_model_type_instance.get_num_tokens(
|
||||
embedding_model_instance.model,
|
||||
embedding_model_instance.credentials,
|
||||
embedding_model_instance.get_text_embedding_num_tokens(
|
||||
[document.page_content]
|
||||
)
|
||||
for document in chunk_documents
|
||||
|
||||
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file.message_file_parser import MessageFileParser
 from core.model_manager import ModelInstance
@@ -9,8 +11,6 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     UserPromptMessage,
 )
-from core.model_runtime.entities.model_entities import ModelType
-from core.model_runtime.model_providers import model_provider_factory
 from extensions.ext_database import db
 from models.model import AppMode, Conversation, Message

@@ -21,7 +21,7 @@ class TokenBufferMemory:
         self.model_instance = model_instance

     def get_history_prompt_messages(self, max_token_limit: int = 2000,
-                                    message_limit: int = 10) -> list[PromptMessage]:
+                                    message_limit: Optional[int] = None) -> list[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit
@@ -30,10 +30,15 @@ class TokenBufferMemory:
         app_record = self.conversation.app

         # fetch limited messages, and return reversed
-        messages = db.session.query(Message).filter(
+        query = db.session.query(Message).filter(
             Message.conversation_id == self.conversation.id,
             Message.answer != ''
-        ).order_by(Message.created_at.desc()).limit(message_limit).all()
+        ).order_by(Message.created_at.desc())
+
+        if message_limit and message_limit > 0:
+            messages = query.limit(message_limit).all()
+        else:
+            messages = query.all()

         messages = list(reversed(messages))
         message_file_parser = MessageFileParser(
@@ -78,12 +83,7 @@ class TokenBufferMemory:
             return []

         # prune the chat message if it exceeds the max token limit
-        provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider)
-        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
-
-        curr_message_tokens = model_type_instance.get_num_tokens(
-            self.model_instance.model,
-            self.model_instance.credentials,
+        curr_message_tokens = self.model_instance.get_llm_num_tokens(
             prompt_messages
         )

@@ -91,9 +91,7 @@ class TokenBufferMemory:
         pruned_memory = []
         while curr_message_tokens > max_token_limit and prompt_messages:
             pruned_memory.append(prompt_messages.pop(0))
-            curr_message_tokens = model_type_instance.get_num_tokens(
-                self.model_instance.model,
-                self.model_instance.credentials,
+            curr_message_tokens = self.model_instance.get_llm_num_tokens(
                 prompt_messages
             )

@@ -102,7 +100,7 @@ class TokenBufferMemory:
     def get_history_prompt_text(self, human_prefix: str = "Human",
                                 ai_prefix: str = "Assistant",
                                 max_token_limit: int = 2000,
-                                message_limit: int = 10) -> str:
+                                message_limit: Optional[int] = None) -> str:
         """
         Get history prompt text.
         :param human_prefix: human prefix
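With `message_limit` now `Optional[int]`, passing `None` (the new default) fetches the whole conversation rather than an arbitrary 10 messages, with pruning still bounded by `max_token_limit`. A usage sketch, assuming `conversation` and `model_instance` objects are already at hand:

```python
# Sketch only; `conversation` and `model_instance` are assumed to exist already.
from core.memory.token_buffer_memory import TokenBufferMemory

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

# New default: all messages, subject only to max_token_limit pruning.
history = memory.get_history_prompt_messages(max_token_limit=2000)

# Or cap the window explicitly, matching the old default behaviour.
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=10)
```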
@@ -1,7 +1,10 @@
 import logging
+import os
 from collections.abc import Generator
 from typing import IO, Optional, Union, cast

-from core.entities.provider_configuration import ProviderModelBundle
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import ModelLoadBalancingConfiguration
+from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult
@@ -9,6 +12,7 @@ from core.model_runtime.entities.message_entities import PromptMessage, PromptMe
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.rerank_entities import RerankResult
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
 from core.model_runtime.model_providers.__base.rerank_model import RerankModel
@@ -16,6 +20,10 @@ from core.model_runtime.model_providers.__base.speech2text_model import Speech2T
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
 from core.model_runtime.model_providers.__base.tts_model import TTSModel
 from core.provider_manager import ProviderManager
+from extensions.ext_redis import redis_client
+from models.provider import ProviderType
+
+logger = logging.getLogger(__name__)


 class ModelInstance:
@@ -29,6 +37,12 @@ class ModelInstance:
         self.provider = provider_model_bundle.configuration.provider.provider
         self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
         self.model_type_instance = self.provider_model_bundle.model_type_instance
+        self.load_balancing_manager = self._get_load_balancing_manager(
+            configuration=provider_model_bundle.configuration,
+            model_type=provider_model_bundle.model_type_instance.model_type,
+            model=model,
+            credentials=self.credentials
+        )

     def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
         """
@@ -37,8 +51,10 @@ class ModelInstance:
         :param model: model name
         :return:
         """
-        credentials = provider_model_bundle.configuration.get_current_credentials(
-            model_type=provider_model_bundle.model_type_instance.model_type,
+        configuration = provider_model_bundle.configuration
+        model_type = provider_model_bundle.model_type_instance.model_type
+        credentials = configuration.get_current_credentials(
+            model_type=model_type,
             model=model
         )
@@ -47,6 +63,43 @@ class ModelInstance:

         return credentials

+    def _get_load_balancing_manager(self, configuration: ProviderConfiguration,
+                                    model_type: ModelType,
+                                    model: str,
+                                    credentials: dict) -> Optional["LBModelManager"]:
+        """
+        Get load balancing model credentials
+        :param configuration: provider configuration
+        :param model_type: model type
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        if configuration.model_settings and configuration.using_provider_type == ProviderType.CUSTOM:
+            current_model_setting = None
+            # check if model is disabled by admin
+            for model_setting in configuration.model_settings:
+                if (model_setting.model_type == model_type
+                        and model_setting.model == model):
+                    current_model_setting = model_setting
+                    break
+
+            # check if load balancing is enabled
+            if current_model_setting and current_model_setting.load_balancing_configs:
+                # use load balancing proxy to choose credentials
+                lb_model_manager = LBModelManager(
+                    tenant_id=configuration.tenant_id,
+                    provider=configuration.provider.provider,
+                    model_type=model_type,
+                    model=model,
+                    load_balancing_configs=current_model_setting.load_balancing_configs,
+                    managed_credentials=credentials if configuration.custom_configuration.provider else None
+                )
+
+                return lb_model_manager
+
+        return None
+
     def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                    tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                    stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
@@ -67,7 +120,8 @@ class ModelInstance:
             raise Exception("Model type instance is not LargeLanguageModel")

         self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             prompt_messages=prompt_messages,
@@ -79,6 +133,27 @@ class ModelInstance:
             callbacks=callbacks
         )

+    def get_llm_num_tokens(self, prompt_messages: list[PromptMessage],
+                           tools: Optional[list[PromptMessageTool]] = None) -> int:
+        """
+        Get number of tokens for llm
+
+        :param prompt_messages: prompt messages
+        :param tools: tools for tool calling
+        :return:
+        """
+        if not isinstance(self.model_type_instance, LargeLanguageModel):
+            raise Exception("Model type instance is not LargeLanguageModel")
+
+        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
+        return self._round_robin_invoke(
+            function=self.model_type_instance.get_num_tokens,
+            model=self.model,
+            credentials=self.credentials,
+            prompt_messages=prompt_messages,
+            tools=tools
+        )
+
     def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
             -> TextEmbeddingResult:
         """
@@ -92,13 +167,32 @@ class ModelInstance:
             raise Exception("Model type instance is not TextEmbeddingModel")

         self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             texts=texts,
             user=user
         )

+    def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
+        """
+        Get number of tokens for text embedding
+
+        :param texts: texts to embed
+        :return:
+        """
+        if not isinstance(self.model_type_instance, TextEmbeddingModel):
+            raise Exception("Model type instance is not TextEmbeddingModel")
+
+        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
+        return self._round_robin_invoke(
+            function=self.model_type_instance.get_num_tokens,
+            model=self.model,
+            credentials=self.credentials,
+            texts=texts
+        )
+
     def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
                       top_n: Optional[int] = None,
                       user: Optional[str] = None) \
@@ -117,7 +211,8 @@ class ModelInstance:
             raise Exception("Model type instance is not RerankModel")

         self.model_type_instance = cast(RerankModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             query=query,
@@ -140,7 +235,8 @@ class ModelInstance:
             raise Exception("Model type instance is not ModerationModel")

         self.model_type_instance = cast(ModerationModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             text=text,
@@ -160,7 +256,8 @@ class ModelInstance:
             raise Exception("Model type instance is not Speech2TextModel")

         self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             file=file,
@@ -183,7 +280,8 @@ class ModelInstance:
             raise Exception("Model type instance is not TTSModel")

         self.model_type_instance = cast(TTSModel, self.model_type_instance)
-        return self.model_type_instance.invoke(
+        return self._round_robin_invoke(
+            function=self.model_type_instance.invoke,
             model=self.model,
             credentials=self.credentials,
             content_text=content_text,
@@ -193,7 +291,44 @@ class ModelInstance:
             streaming=streaming
         )

-    def get_tts_voices(self, language: str) -> list:
+    def _round_robin_invoke(self, function: callable, *args, **kwargs):
+        """
+        Round-robin invoke
+        :param function: function to invoke
+        :param args: function args
+        :param kwargs: function kwargs
+        :return:
+        """
+        if not self.load_balancing_manager:
+            return function(*args, **kwargs)
+
+        last_exception = None
+        while True:
+            lb_config = self.load_balancing_manager.fetch_next()
+            if not lb_config:
+                if not last_exception:
+                    raise ProviderTokenNotInitError("Model credentials is not initialized.")
+                else:
+                    raise last_exception
+
+            try:
+                if 'credentials' in kwargs:
+                    del kwargs['credentials']
+                return function(*args, **kwargs, credentials=lb_config.credentials)
+            except InvokeRateLimitError as e:
+                # expire in 60 seconds
+                self.load_balancing_manager.cooldown(lb_config, expire=60)
+                last_exception = e
+                continue
+            except (InvokeAuthorizationError, InvokeConnectionError) as e:
+                # expire in 10 seconds
+                self.load_balancing_manager.cooldown(lb_config, expire=10)
+                last_exception = e
+                continue
+            except Exception as e:
+                raise e
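The retry loop above swaps credentials per attempt and puts failing configs on cooldown (60 s for rate limits, 10 s for auth/connection errors), raising the last exception only once every config is cooling down. A small self-contained sketch of the same control flow, deliberately independent of the Dify classes and Redis (all names here are illustrative):

```python
import time
from typing import Callable, Optional

class SimpleBalancer:
    """Toy stand-in for LBModelManager: round robin with per-config cooldown."""

    def __init__(self, credentials_list: list[dict]):
        self._configs = credentials_list
        self._index = 0
        self._cooldown_until: dict[int, float] = {}

    def fetch_next(self) -> Optional[dict]:
        now = time.monotonic()
        for _ in range(len(self._configs)):
            i = self._index % len(self._configs)
            self._index += 1
            if self._cooldown_until.get(i, 0) <= now:
                return self._configs[i]
        return None  # every config is cooling down

    def cooldown(self, config: dict, expire: int) -> None:
        self._cooldown_until[self._configs.index(config)] = time.monotonic() + expire

def round_robin_invoke(balancer: SimpleBalancer, function: Callable, **kwargs):
    last_exception = None
    while True:
        config = balancer.fetch_next()
        if config is None:
            raise last_exception or RuntimeError("no credentials available")
        try:
            return function(credentials=config, **kwargs)
        except TimeoutError as e:  # stand-in for InvokeRateLimitError and friends
            balancer.cooldown(config, expire=60)
            last_exception = e
```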
+
+    def get_tts_voices(self, language: Optional[str] = None) -> list:
         """
         Invoke large language tts model voices

@@ -226,6 +361,7 @@ class ModelManager:
         """
         if not provider:
             return self.get_default_model_instance(tenant_id, model_type)
+
         provider_model_bundle = self._provider_manager.get_provider_model_bundle(
             tenant_id=tenant_id,
             provider=provider,
@@ -255,3 +391,141 @@ class ModelManager:
             model_type=model_type,
             model=default_model_entity.model
         )
+
+
+class LBModelManager:
+    def __init__(self, tenant_id: str,
+                 provider: str,
+                 model_type: ModelType,
+                 model: str,
+                 load_balancing_configs: list[ModelLoadBalancingConfiguration],
+                 managed_credentials: Optional[dict] = None) -> None:
+        """
+        Load balancing model manager
+        :param load_balancing_configs: all load balancing configurations
+        :param managed_credentials: credentials if load balancing configuration name is __inherit__
+        """
+        self._tenant_id = tenant_id
+        self._provider = provider
+        self._model_type = model_type
+        self._model = model
+        self._load_balancing_configs = load_balancing_configs
+
+        for load_balancing_config in self._load_balancing_configs:
+            if load_balancing_config.name == "__inherit__":
+                if not managed_credentials:
+                    # remove __inherit__ if managed credentials is not provided
+                    self._load_balancing_configs.remove(load_balancing_config)
+                else:
+                    load_balancing_config.credentials = managed_credentials
+
+    def fetch_next(self) -> Optional[ModelLoadBalancingConfiguration]:
+        """
+        Get next model load balancing config
+        Strategy: Round Robin
+        :return:
+        """
+        cache_key = "model_lb_index:{}:{}:{}:{}".format(
+            self._tenant_id,
+            self._provider,
+            self._model_type.value,
+            self._model
+        )
+
+        cooldown_load_balancing_configs = []
+        max_index = len(self._load_balancing_configs)
+
+        while True:
+            current_index = redis_client.incr(cache_key)
+            if current_index >= 10000000:
+                current_index = 1
+                redis_client.set(cache_key, current_index)
+
+            redis_client.expire(cache_key, 3600)
+            if current_index > max_index:
+                current_index = current_index % max_index
+
+            real_index = current_index - 1
+            if real_index > max_index:
+                real_index = 0
+
+            config = self._load_balancing_configs[real_index]
+
+            if self.in_cooldown(config):
+                cooldown_load_balancing_configs.append(config)
+                if len(cooldown_load_balancing_configs) >= len(self._load_balancing_configs):
+                    # all configs are in cooldown
+                    return None
+
+                continue
+
+            if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
+                logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n"
+                            f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
+                            f"model_type: {self._model_type.value}\nmodel: {self._model}")
+
+            return config
+
+        return None
+
+    def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
+        """
+        Cooldown model load balancing config
+        :param config: model load balancing config
+        :param expire: cooldown time
+        :return:
+        """
+        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
+            self._tenant_id,
+            self._provider,
+            self._model_type.value,
+            self._model,
+            config.id
+        )
+
+        redis_client.setex(cooldown_cache_key, expire, 'true')
+
+    def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
+        """
+        Check if model load balancing config is in cooldown
+        :param config: model load balancing config
+        :return:
+        """
+        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
+            self._tenant_id,
+            self._provider,
+            self._model_type.value,
+            self._model,
+            config.id
+        )
+
+        return redis_client.exists(cooldown_cache_key)
+
+    @classmethod
+    def get_config_in_cooldown_and_ttl(cls, tenant_id: str,
+                                       provider: str,
+                                       model_type: ModelType,
+                                       model: str,
+                                       config_id: str) -> tuple[bool, int]:
+        """
+        Get model load balancing config is in cooldown and ttl
+        :param tenant_id: workspace id
+        :param provider: provider name
+        :param model_type: model type
+        :param model: model name
+        :param config_id: model load balancing config id
+        :return:
+        """
+        cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
+            tenant_id,
+            provider,
+            model_type.value,
+            model,
+            config_id
+        )
+
+        ttl = redis_client.ttl(cooldown_cache_key)
+        if ttl == -2:
+            return False, 0
+
+        return True, ttl
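`fetch_next` keeps the rotation cursor in Redis (`INCR` with a periodic reset at 10,000,000) so that all API workers share one round-robin position. The index arithmetic reduces to the following plain-Python sketch, with a fake counter standing in for Redis; note the `-1` at the wrap point, which still selects the last config thanks to Python's negative indexing:

```python
# Plain-Python sketch of the index arithmetic in fetch_next (Redis replaced by a counter).
def next_real_index(counter: int, max_index: int) -> int:
    current_index = counter
    if current_index > max_index:
        current_index = current_index % max_index
    real_index = current_index - 1
    if real_index > max_index:
        real_index = 0
    return real_index

# With 3 configs and successive INCR values 1..6:
print([next_real_index(i, 3) for i in range(1, 7)])  # [0, 1, 2, 0, 1, -1]
# configs[-1] is the last config, so the rotation 0 -> 1 -> 2 still holds.
```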
@@ -20,7 +20,7 @@ This module provides the interface for invoking and authenticating various model



-Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./schema.md).
+Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).

 - Selectable model list display


@@ -336,7 +336,7 @@ Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement
 - Invoke Invocation

   ```python
-  def _invoke(elf, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
+  def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
      """
      Invoke large language model


@@ -376,7 +376,7 @@ class XinferenceProvider(Provider):
 - Invoke Invocation

   ```python
-  def _invoke(elf, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
+  def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None):
      """
      Invoke large language model

@@ -3,6 +3,7 @@ import os
 from abc import ABC, abstractmethod
 from typing import Optional

+from core.helper.position_helper import get_position_map, sort_by_position_map
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
 from core.model_runtime.entities.model_entities import (
@@ -17,7 +18,6 @@ from core.model_runtime.entities.model_entities import (
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
 from core.tools.utils.yaml_utils import load_yaml_file
-from core.utils.position_helper import get_position_map, sort_by_position_map


 class AIModel(ABC):

@@ -1,11 +1,11 @@
 import os
 from abc import ABC, abstractmethod

+from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.entities.provider_entities import ProviderEntity
 from core.model_runtime.model_providers.__base.ai_model import AIModel
 from core.tools.utils.yaml_utils import load_yaml_file
-from core.utils.module_import_helper import get_subclasses_from_module, import_module_from_source


 class ModelProvider(ABC):
@@ -4,6 +4,7 @@
 - google
 - vertex_ai
 - nvidia
+- nvidia_nim
 - cohere
 - bedrock
 - togetherai
@@ -30,3 +31,5 @@
 - volcengine_maas
 - openai_api_compatible
 - deepseek
+- hunyuan
+- siliconflow
@@ -53,6 +53,15 @@ model_credential_schema:
       type: select
       required: true
       options:
+        - label:
+            en_US: 2024-05-01-preview
+          value: 2024-05-01-preview
+        - label:
+            en_US: 2024-04-01-preview
+          value: 2024-04-01-preview
+        - label:
+            en_US: 2024-03-01-preview
+          value: 2024-03-01-preview
         - label:
             en_US: 2024-02-15-preview
           value: 2024-02-15-preview
@@ -6,7 +6,7 @@ features:
   - agent-thought
 model_properties:
   mode: chat
-  context_size: 4000
+  context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature

@@ -6,7 +6,7 @@ features:
   - agent-thought
 model_properties:
   mode: chat
-  context_size: 192000
+  context_size: 32000
 parameter_rules:
   - name: temperature
     use_template: temperature
@@ -0,0 +1,45 @@
model: baichuan3-turbo-128k
label:
  en_US: Baichuan3-Turbo-128k
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8000
    min: 1
    max: 128000
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
    default: 1
    min: 1
    max: 2
  - name: with_search_enhance
    label:
      zh_Hans: 搜索增强
      en_US: Search Enhance
    type: boolean
    help:
      zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
      en_US: Allow the model to perform external search to enhance the generation results.
    required: false

@@ -0,0 +1,45 @@
model: baichuan3-turbo
label:
  en_US: Baichuan3-Turbo
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8000
    min: 1
    max: 32000
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
    default: 1
    min: 1
    max: 2
  - name: with_search_enhance
    label:
      zh_Hans: 搜索增强
      en_US: Search Enhance
    type: boolean
    help:
      zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
      en_US: Allow the model to perform external search to enhance the generation results.
    required: false

@@ -0,0 +1,45 @@
model: baichuan4
label:
  en_US: Baichuan4
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_tokens
    use_template: max_tokens
    required: true
    default: 8000
    min: 1
    max: 32000
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
    default: 1
    min: 1
    max: 2
  - name: with_search_enhance
    label:
      zh_Hans: 搜索增强
      en_US: Search Enhance
    type: boolean
    help:
      zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
      en_US: Allow the model to perform external search to enhance the generation results.
    required: false
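These YAML `parameter_rules` drive the parameter form in the UI; at invoke time the selected values arrive as a plain `model_parameters` dict keyed by rule name. For a concrete picture, a call against `baichuan4` configured from the file above might pass something like this (a sketch with illustrative values):

```python
# Illustrative model_parameters dict produced from the parameter_rules above.
model_parameters = {
    'temperature': 0.7,
    'top_p': 0.9,
    'top_k': 5,
    'max_tokens': 8000,           # the rule's default; capped at 32000 for baichuan4
    'with_search_enhance': True,  # boolean rule defined in the YAML
}
```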
@@ -51,26 +51,29 @@ class BaichuanModel:
             'baichuan2-turbo': 'Baichuan2-Turbo',
             'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
             'baichuan2-53b': 'Baichuan2-53B',
+            'baichuan3-turbo': 'Baichuan3-Turbo',
+            'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
+            'baichuan4': 'Baichuan4',
         }[model]

     def _handle_chat_generate_response(self, response) -> BaichuanMessage:
-        resp = response.json()
-        choices = resp.get('choices', [])
-        message = BaichuanMessage(content='', role='assistant')
-        for choice in choices:
-            message.content += choice['message']['content']
-            message.role = choice['message']['role']
-            if choice['finish_reason']:
-                message.stop_reason = choice['finish_reason']
+        resp = response.json()
+        choices = resp.get('choices', [])
+        message = BaichuanMessage(content='', role='assistant')
+        for choice in choices:
+            message.content += choice['message']['content']
+            message.role = choice['message']['role']
+            if choice['finish_reason']:
+                message.stop_reason = choice['finish_reason']

-        if 'usage' in resp:
-            message.usage = {
-                'prompt_tokens': resp['usage']['prompt_tokens'],
-                'completion_tokens': resp['usage']['completion_tokens'],
-                'total_tokens': resp['usage']['total_tokens'],
-            }
-
-        return message
+        if 'usage' in resp:
+            message.usage = {
+                'prompt_tokens': resp['usage']['prompt_tokens'],
+                'completion_tokens': resp['usage']['completion_tokens'],
+                'total_tokens': resp['usage']['total_tokens'],
+            }
+
+        return message

     def _handle_chat_stream_generate_response(self, response) -> Generator:
         for line in response.iter_lines():
@@ -110,7 +113,8 @@ class BaichuanModel:
     def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
                           parameters: dict[str, Any]) \
             -> dict[str, Any]:
-        if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
+        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
+                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
             prompt_messages = []
             for message in messages:
                 if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
@@ -143,7 +147,8 @@ class BaichuanModel:
             raise BadRequestError(f"Unknown model: {model}")

     def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
-        if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
+        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
+                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
             # there is no secret key for turbo api
             return {
                 'Content-Type': 'application/json',
@@ -160,7 +165,8 @@ class BaichuanModel:
                  parameters: dict[str, Any], timeout: int) \
             -> Union[Generator, BaichuanMessage]:

-        if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
+        if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
+                or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
            api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
        else:
            raise BadRequestError(f"Unknown model: {model}")
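The three widened `if` chains above repeat the same six string comparisons. A set-membership test would express the same gate more compactly; this is an editorial sketch of a possible tidier form, not what the patch does (`BadRequestError` is taken from the surrounding module):

```python
# Editorial sketch: the same model gate expressed as a set-membership test.
_TURBO_API_MODELS = {
    'baichuan2-turbo', 'baichuan2-turbo-192k', 'baichuan2-53b',
    'baichuan3-turbo', 'baichuan3-turbo-128k', 'baichuan4',
}

if model in _TURBO_API_MODELS:
    api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
else:
    raise BadRequestError(f"Unknown model: {model}")
```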
@@ -7,6 +7,7 @@ from core.model_runtime.entities.message_entities import (
     PromptMessage,
     PromptMessageTool,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.errors.invoke import (
@@ -32,20 +33,21 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor


 class BaichuanLarguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict,
-                prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                 stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
-                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
+                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: list[PromptMessageTool] | None = None) -> int:
         return self._num_tokens_from_messages(prompt_messages)

-    def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int:
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
         """Calculate num tokens for baichuan model"""

         def tokens(text: str):
             return BaichuanTokenizer._get_num_tokens(text)

@@ -85,9 +87,20 @@ class BaichuanLarguageModel(LargeLanguageModel):
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
             message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, ToolPromptMessage):
+            # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
+            message = cast(ToolPromptMessage, message)
+            message_dict = {
+                "role": "user",
+                "content": [{
+                    "type": "tool_result",
+                    "tool_use_id": message.tool_call_id,
+                    "content": message.content
+                }]
+            }
         else:
             raise ValueError(f"Unknown message type {type(message)}")

         return message_dict

     def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -106,13 +119,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
         except Exception as e:
             raise CredentialsValidateFailedError(f"Invalid API key: {e}")

-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         if tools is not None and len(tools) > 0:
             raise InvokeBadRequestError("Baichuan model doesn't support tools")

         instance = BaichuanModel(
             api_key=credentials['api_key'],
             secret_key=credentials.get('secret_key', '')
@@ -129,11 +142,12 @@ class BaichuanLarguageModel(LargeLanguageModel):
         ]

         # invoke model
-        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters, timeout=60)
+        response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
+                                     timeout=60)

         if stream:
             return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)

         return self._handle_chat_generate_response(model, prompt_messages, credentials, response)

     def _handle_chat_generate_response(self, model: str,
@@ -141,7 +155,9 @@ class BaichuanLarguageModel(LargeLanguageModel):
                                        credentials: dict,
                                        response: BaichuanMessage) -> LLMResult:
         # convert baichuan message to llm result
-        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=response.usage['prompt_tokens'], completion_tokens=response.usage['completion_tokens'])
+        usage = self._calc_response_usage(model=model, credentials=credentials,
+                                          prompt_tokens=response.usage['prompt_tokens'],
+                                          completion_tokens=response.usage['completion_tokens'])
         return LLMResult(
             model=model,
             prompt_messages=prompt_messages,
@@ -158,7 +174,9 @@ class BaichuanLarguageModel(LargeLanguageModel):
                                               response: Generator[BaichuanMessage, None, None]) -> Generator:
         for message in response:
             if message.usage:
-                usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=message.usage['prompt_tokens'], completion_tokens=message.usage['completion_tokens'])
+                usage = self._calc_response_usage(model=model, credentials=credentials,
+                                                  prompt_tokens=message.usage['prompt_tokens'],
+                                                  completion_tokens=message.usage['completion_tokens'])
                 yield LLMResultChunk(
                     model=model,
                     prompt_messages=prompt_messages,
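The new `ToolPromptMessage` branch (borrowed, per its comment, from the Anthropic provider) serializes a tool result as a user-role content block. For a concrete picture, a message with `tool_call_id='call_1'` and content `'42'` would become (values illustrative):

```python
# Resulting dict for ToolPromptMessage(tool_call_id='call_1', content='42') -- a sketch.
message_dict = {
    "role": "user",
    "content": [{
        "type": "tool_result",
        "tool_use_id": "call_1",
        "content": "42",
    }],
}
```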
@@ -358,26 +358,25 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

         return message_dict

-    def get_num_tokens(self, model: str, credentials: dict, messages: list[PromptMessage] | str,
+    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
         Get number of tokens for given prompt messages

         :param model: model name
         :param credentials: model credentials
-        :param messages: prompt messages or message string
+        :param prompt_messages: prompt messages or message string
         :param tools: tools for tool calling
         :return:md = genai.GenerativeModel(model)
         """
         prefix = model.split('.')[0]
         model_name = model.split('.')[1]
-        if isinstance(messages, str):
-            prompt = messages
+        if isinstance(prompt_messages, str):
+            prompt = prompt_messages
         else:
-            prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
+            prompt = self._convert_messages_to_prompt(prompt_messages, prefix, model_name)

         return self._get_num_tokens_by_gpt2(prompt)

     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -59,15 +59,15 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         model_prefix = model.split('.')[0]

         if model_prefix == "amazon" :
-            for text in texts:
-                body = {
+            for text in texts:
+                body = {
                     "inputText": text,
-                }
-                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
-                embeddings.extend([response_body.get('embedding')])
-                token_usage += response_body.get('inputTextTokenCount')
-            logger.warning(f'Total Tokens: {token_usage}')
-            result = TextEmbeddingResult(
+                }
+                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
+                embeddings.extend([response_body.get('embedding')])
+                token_usage += response_body.get('inputTextTokenCount')
+            logger.warning(f'Total Tokens: {token_usage}')
+            result = TextEmbeddingResult(
                 model=model,
                 embeddings=embeddings,
                 usage=self._calc_response_usage(
@@ -75,20 +75,20 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
                     credentials=credentials,
                     tokens=token_usage
                 )
-                )
-            return result
-
+            )
+            return result
+
         if model_prefix == "cohere" :
-            input_type = 'search_document' if len(texts) > 1 else 'search_query'
-            for text in texts:
-                body = {
+            input_type = 'search_document' if len(texts) > 1 else 'search_query'
+            for text in texts:
+                body = {
                     "texts": [text],
                     "input_type": input_type,
-                }
-                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
-                embeddings.extend(response_body.get('embeddings'))
-                token_usage += len(text)
-            result = TextEmbeddingResult(
+                }
+                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
+                embeddings.extend(response_body.get('embeddings'))
+                token_usage += len(text)
+            result = TextEmbeddingResult(
                 model=model,
                 embeddings=embeddings,
                 usage=self._calc_response_usage(
@@ -96,9 +96,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
                     credentials=credentials,
                     tokens=token_usage
                 )
-                )
-            return result
-
+            )
+            return result
+
         #others
         raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")

@@ -183,7 +183,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         )

         return usage
-
+
     def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
         """
         Map client error to invoke error
@@ -212,9 +212,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         content_type = 'application/json'
         try:
             response = bedrock_runtime.invoke_model(
-                body=json.dumps(body),
-                modelId=model,
-                accept=accept,
+                body=json.dumps(body),
+                modelId=model,
+                accept=accept,
                 contentType=content_type
             )
             response_body = json.loads(response.get('body').read().decode('utf-8'))
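The amazon and cohere branches build differently shaped per-text request bodies before calling `invoke_model`. As a compact reference, the two payload shapes look like this (a sketch under the assumption that the bodies shown above are complete):

```python
# Per-prefix Bedrock embedding request bodies, as built in the loops above (sketch).
def build_embedding_body(model_prefix: str, text: str, input_type: str) -> dict:
    if model_prefix == "amazon":
        return {"inputText": text}
    if model_prefix == "cohere":
        # input_type is 'search_document' for batches, 'search_query' for a single text
        return {"texts": [text], "input_type": input_type}
    raise ValueError(f"Got unknown model prefix {model_prefix}")
```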
@@ -1,18 +1,22 @@
+import base64
 import json
 import logging
+import mimetypes
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast

 import google.ai.generativelanguage as glm
 import google.api_core.exceptions as exceptions
 import google.generativeai as genai
 import google.generativeai.client as client
+import requests
 from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory
 from google.generativeai.types.content_types import to_part

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
+    ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
     PromptMessageTool,
@@ -204,6 +208,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             stream=stream,
             safety_settings=safety_settings,
             tools=self._convert_tools_to_glm_tool(tools) if tools else None,
+            request_options={"timeout": 600}
         )

         if stream:
@@ -360,11 +365,22 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             for c in message.content:
                 if c.type == PromptMessageContentType.TEXT:
                     glm_content['parts'].append(to_part(c.data))
-                else:
-                    metadata, data = c.data.split(',', 1)
-                    mime_type = metadata.split(';', 1)[0].split(':')[1]
-                    blob = {"inline_data":{"mime_type":mime_type,"data":data}}
+                elif c.type == PromptMessageContentType.IMAGE:
+                    message_content = cast(ImagePromptMessageContent, c)
+                    if message_content.data.startswith("data:"):
+                        metadata, base64_data = c.data.split(',', 1)
+                        mime_type = metadata.split(';', 1)[0].split(':')[1]
+                    else:
+                        # fetch image data from url
+                        try:
+                            image_content = requests.get(message_content.data).content
+                            mime_type, _ = mimetypes.guess_type(message_content.data)
+                            base64_data = base64.b64encode(image_content).decode('utf-8')
+                        except Exception as ex:
+                            raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
+                    blob = {"inline_data":{"mime_type":mime_type,"data":base64_data}}
                     glm_content['parts'].append(blob)

             return glm_content
         elif isinstance(message, AssistantPromptMessage):
             glm_content = {
@@ -443,4 +459,4 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                 exceptions.RequestRangeNotSatisfiable,
                 exceptions.Cancelled,
             ]
-        }
+        }
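The image branch now distinguishes inline data URIs from remote URLs (which are fetched and base64-encoded). The data-URI split works like this, as a standalone sketch with a truncated example payload:

```python
# Parsing an image data URI the way the diff above does (sketch).
data_uri = "data:image/png;base64,iVBORw0KGgoAAA..."  # truncated example payload

metadata, base64_data = data_uri.split(',', 1)       # "data:image/png;base64" / payload
mime_type = metadata.split(';', 1)[0].split(':')[1]  # -> "image/png"

blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
```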
Two binary image files added (67 KiB and 59 KiB), not shown.

api/core/model_runtime/model_providers/hunyuan/hunyuan.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class HunyuanProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials
        if validate failed, raise exception

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Use `hunyuan-standard` model for validate,
            model_instance.validate_credentials(
                model='hunyuan-standard',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
api/core/model_runtime/model_providers/hunyuan/hunyuan.yaml (new file, 40 lines)

@@ -0,0 +1,40 @@
provider: hunyuan
label:
  zh_Hans: 腾讯混元
  en_US: Hunyuan
description:
  en_US: Models provided by Tencent Hunyuan, such as hunyuan-standard, hunyuan-standard-256k, hunyuan-pro and hunyuan-lite.
  zh_Hans: 腾讯混元提供的模型,例如 hunyuan-standard、 hunyuan-standard-256k, hunyuan-pro 和 hunyuan-lite。
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
background: "#F6F7F7"
help:
  title:
    en_US: Get your API Key from Tencent Hunyuan
    zh_Hans: 从腾讯混元获取 API Key
  url:
    en_US: https://console.cloud.tencent.com/cam/capi
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: secret_id
      label:
        en_US: Secret ID
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 Secret ID
        en_US: Enter your Secret ID
    - variable: secret_key
      label:
        en_US: Secret Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 Secret Key
        en_US: Enter your Secret Key
@@ -0,0 +1,4 @@
- hunyuan-lite
- hunyuan-standard
- hunyuan-standard-256k
- hunyuan-pro
@@ -0,0 +1,28 @@
model: hunyuan-lite
label:
  zh_Hans: hunyuan-lite
  en_US: hunyuan-lite
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 256000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 256000
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.001'
  currency: RMB

@@ -0,0 +1,28 @@
model: hunyuan-pro
label:
  zh_Hans: hunyuan-pro
  en_US: hunyuan-pro
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
pricing:
  input: '0.03'
  output: '0.10'
  unit: '0.001'
  currency: RMB

@@ -0,0 +1,28 @@
model: hunyuan-standard-256k
label:
  zh_Hans: hunyuan-standard-256k
  en_US: hunyuan-standard-256k
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 256000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 256000
pricing:
  input: '0.015'
  output: '0.06'
  unit: '0.001'
  currency: RMB

@@ -0,0 +1,28 @@
model: hunyuan-standard
label:
  zh_Hans: hunyuan-standard
  en_US: hunyuan-standard
model_type: llm
features:
  - agent-thought
  - tool-call
  - multi-tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 32000
pricing:
  input: '0.0045'
  output: '0.0005'
  unit: '0.001'
  currency: RMB
api/core/model_runtime/model_providers/hunyuan/llm/llm.py (new file, 205 lines)

@@ -0,0 +1,205 @@
import json
import logging
from collections.abc import Generator

from tencentcloud.common import credential
from tencentcloud.common.exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)


class HunyuanLargeLanguageModel(LargeLanguageModel):

    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:

        client = self._setup_hunyuan_client(credentials)
        request = models.ChatCompletionsRequest()
        messages_dict = self._convert_prompt_messages_to_dicts(prompt_messages)

        custom_parameters = {
            'Temperature': model_parameters.get('temperature', 0.0),
            'TopP': model_parameters.get('top_p', 1.0)
        }

        params = {
            "Model": model,
            "Messages": messages_dict,
            "Stream": stream,
            **custom_parameters,
        }

        request.from_json_string(json.dumps(params))
        response = client.ChatCompletions(request)

        if stream:
            return self._handle_stream_chat_response(model, credentials, prompt_messages, response)

        return self._handle_chat_response(credentials, model, prompt_messages, response)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate credentials
        """
        try:
            client = self._setup_hunyuan_client(credentials)

            req = models.ChatCompletionsRequest()
            params = {
                "Model": model,
                "Messages": [{
                    "Role": "user",
                    "Content": "hello"
                }],
                "TopP": 1,
                "Temperature": 0,
                "Stream": False
            }
            req.from_json_string(json.dumps(params))
            client.ChatCompletions(req)
        except Exception as e:
            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')

    def _setup_hunyuan_client(self, credentials):
        secret_id = credentials['secret_id']
        secret_key = credentials['secret_key']
        cred = credential.Credential(secret_id, secret_key)
        httpProfile = HttpProfile()
        httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
        clientProfile = ClientProfile()
        clientProfile.httpProfile = httpProfile
        client = hunyuan_client.HunyuanClient(cred, "", clientProfile)
        return client

    def _convert_prompt_messages_to_dicts(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """Convert a list of PromptMessage objects to a list of dictionaries with 'Role' and 'Content' keys."""
        return [{"Role": message.role.value, "Content": message.content} for message in prompt_messages]

    def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp):
        for index, event in enumerate(resp):
            logging.debug("_handle_stream_chat_response, event: %s", event)

            data_str = event['data']
            data = json.loads(data_str)

            choices = data.get('Choices', [])
            if not choices:
                continue
            choice = choices[0]
            delta = choice.get('Delta', {})
            message_content = delta.get('Content', '')
            finish_reason = choice.get('FinishReason', '')

            usage = data.get('Usage', {})
            prompt_tokens = usage.get('PromptTokens', 0)
            completion_tokens = usage.get('CompletionTokens', 0)
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

            assistant_prompt_message = AssistantPromptMessage(
                content=message_content,
                tool_calls=[]
            )

            delta_chunk = LLMResultChunkDelta(
                index=index,
                role=delta.get('Role', 'assistant'),
                message=assistant_prompt_message,
                usage=usage,
                finish_reason=finish_reason,
            )

            yield LLMResultChunk(
                model=model,
                prompt_messages=prompt_messages,
                delta=delta_chunk,
            )

    def _handle_chat_response(self, credentials, model, prompt_messages, response):
        usage = self._calc_response_usage(model, credentials, response.Usage.PromptTokens,
                                          response.Usage.CompletionTokens)
        assistant_prompt_message = PromptMessage(role="assistant")
        assistant_prompt_message.content = response.Choices[0].Message.Content
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
        )

        return result

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: list[PromptMessageTool] | None = None) -> int:
        if len(prompt_messages) == 0:
            return 0
        prompt = self._convert_messages_to_prompt(prompt_messages)
        return self._get_num_tokens_by_gpt2(prompt)

    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
        """
        Format a list of messages into a full prompt for the Anthropic model

        :param messages: List of PromptMessage to combine.
        :return: Combined string with necessary human_prompt and ai_prompt tags.
        """
        messages = messages.copy()  # don't mutate the original list

        text = "".join(
            self._convert_one_message_to_text(message)
            for message in messages
        )

        # trim off the trailing ' ' that might come from the "Assistant: "
        return text.rstrip()

    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
        """
        Convert a single message to a string.

        :param message: PromptMessage to convert.
        :return: String representation of the message.
        """
        human_prompt = "\n\nHuman:"
        ai_prompt = "\n\nAssistant:"
        content = message.content

        if isinstance(message, UserPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AssistantPromptMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            message_text = content
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_text

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeError: [TencentCloudSDKException],
        }
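A minimal standalone sketch of the Tencent SDK round trip the provider performs, assembled only from calls the new file itself uses; the secret values below are placeholders, not real credentials:

```python
import json

from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models

# Placeholder credentials; real values come from the provider credential form.
cred = credential.Credential("SECRET_ID", "SECRET_KEY")
http_profile = HttpProfile()
http_profile.endpoint = "hunyuan.tencentcloudapi.com"
client_profile = ClientProfile()
client_profile.httpProfile = http_profile
client = hunyuan_client.HunyuanClient(cred, "", client_profile)

# Non-streaming chat completion, mirroring validate_credentials above.
req = models.ChatCompletionsRequest()
req.from_json_string(json.dumps({
    "Model": "hunyuan-standard",
    "Messages": [{"Role": "user", "Content": "hello"}],
    "Stream": False,
}))
resp = client.ChatCompletions(req)
print(resp.Choices[0].Message.Content)
```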
@@ -0,0 +1,9 @@
model: jina-clip-v1
model_type: text-embedding
model_properties:
  context_size: 8192
  max_chunks: 2048
pricing:
  input: '0.001'
  unit: '0.001'
  currency: USD
@@ -52,16 +52,21 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
            'Content-Type': 'application/json'
        }

        def transform_jina_input_text(model, text):
            if model == 'jina-clip-v1':
                return {"text": text}
            return text

        data = {
            'model': model,
            'input': texts
            'input': [transform_jina_input_text(model, text) for text in texts]
        }

        try:
            response = post(url, headers=headers, data=dumps(data))
        except Exception as e:
            raise InvokeConnectionError(str(e))

        if response.status_code != 200:
            try:
                resp = response.json()
@@ -75,16 +80,19 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
                else:
                    raise InvokeBadRequestError(msg)
            except JSONDecodeError as e:
                raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
                raise InvokeServerUnavailableError(
                    f"Failed to convert response to json: {e} with text: {response.text}")

        try:
            resp = response.json()
            embeddings = resp['data']
            usage = resp['usage']
        except Exception as e:
            raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
            raise InvokeServerUnavailableError(
                f"Failed to convert response to json: {e} with text: {response.text}")

        usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage['total_tokens'])
        usage = self._calc_response_usage(
            model=model, credentials=credentials, tokens=usage['total_tokens'])

        result = TextEmbeddingResult(
            model=model,
@@ -122,7 +130,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
        try:
            self._invoke(model=model, credentials=credentials, texts=['ping'])
        except Exception as e:
            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
            raise CredentialsValidateFailedError(
                f'Credentials validation failed: {e}')

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
@@ -144,7 +153,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
                InvokeBadRequestError
            ]
        }

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage
@@ -185,7 +194,8 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
            model_type=ModelType.TEXT_EMBEDDING,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size'))
                ModelPropertyKey.CONTEXT_SIZE: int(
                    credentials.get('context_size'))
            }
        )

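The `transform_jina_input_text` change exists because the multimodal jina-clip-v1 endpoint expects objects rather than bare strings in `input`. A small sketch of the resulting request bodies; the model names follow Jina's public API and the sample text is made up:

import json

def build_payload(model: str, texts: list) -> str:
    # Mirrors the patched code path above: wrap plain strings for
    # jina-clip-v1, pass them through unchanged for text-only models.
    def transform(text):
        return {"text": text} if model == 'jina-clip-v1' else text
    return json.dumps({'model': model, 'input': [transform(t) for t in texts]})

print(build_payload('jina-clip-v1', ['a photo of a cat']))
# {"model": "jina-clip-v1", "input": [{"text": "a photo of a cat"}]}
print(build_payload('jina-embeddings-v2-base-en', ['a photo of a cat']))
# {"model": "jina-embeddings-v2-base-en", "input": ["a photo of a cat"]}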
@@ -27,6 +27,7 @@ from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
@@ -50,14 +51,14 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.utils import helper


class LocalAILarguageModel(LargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
class LocalAILanguageModel(LargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:
        return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: list[PromptMessageTool] | None = None) -> int:
@@ -67,8 +68,9 @@ class LocalAILarguageModel(LargeLanguageModel):
    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int:
        """
        Calculate num tokens for baichuan model
        LocalAI does not supports
        LocalAI does not supports
        """

        def tokens(text: str):
            """
            We cloud not determine which tokenizer to use, cause the model is customized.
@@ -124,7 +126,7 @@ class LocalAILarguageModel(LargeLanguageModel):
            num_tokens += self._num_tokens_for_tools(tools)

        return num_tokens

    def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
        """
        Calculate num tokens for tool calling
@@ -133,6 +135,7 @@ class LocalAILarguageModel(LargeLanguageModel):
        :param tools: tools for tool calling
        :return: number of tokens
        """

        def tokens(text: str):
            return self._get_num_tokens_by_gpt2(text)

@@ -193,7 +196,7 @@ class LocalAILarguageModel(LargeLanguageModel):
            completion_model = LLMMode.COMPLETION.value
        else:
            raise ValueError(f"Unknown completion type {credentials['completion_type']}")

        rules = [
            ParameterRule(
                name='temperature',
@@ -227,7 +230,7 @@ class LocalAILarguageModel(LargeLanguageModel):
            )
        ]

        model_properties = {
        model_properties = {
            ModelPropertyKey.MODE: completion_model,
        } if completion_model else {}

@@ -246,11 +249,11 @@ class LocalAILarguageModel(LargeLanguageModel):

        return entity

    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
            -> LLMResult | Generator:

        kwargs = self._to_client_kwargs(credentials)
        # init model client
        client = OpenAI(**kwargs)
@@ -271,7 +274,7 @@ class LocalAILarguageModel(LargeLanguageModel):
            extra_model_kwargs['functions'] = [
                helper.dump_model(tool) for tool in tools
            ]

        if completion_type == 'chat_completion':
            result = client.chat.completions.create(
                messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
@@ -294,24 +297,24 @@ class LocalAILarguageModel(LargeLanguageModel):
        if stream:
            if completion_type == 'completion':
                return self._handle_completion_generate_stream_response(
                    model=model, credentials=credentials, response=result, tools=tools,
                    model=model, credentials=credentials, response=result, tools=tools,
                    prompt_messages=prompt_messages
                )
            return self._handle_chat_generate_stream_response(
                model=model, credentials=credentials, response=result, tools=tools,
                model=model, credentials=credentials, response=result, tools=tools,
                prompt_messages=prompt_messages
            )

        if completion_type == 'completion':
            return self._handle_completion_generate_response(
                model=model, credentials=credentials, response=result,
                model=model, credentials=credentials, response=result,
                prompt_messages=prompt_messages
            )
        return self._handle_chat_generate_response(
            model=model, credentials=credentials, response=result, tools=tools,
            model=model, credentials=credentials, response=result, tools=tools,
            prompt_messages=prompt_messages
        )

    def _to_client_kwargs(self, credentials: dict) -> dict:
        """
        Convert invoke kwargs to client kwargs
@@ -321,7 +324,7 @@ class LocalAILarguageModel(LargeLanguageModel):
        """
        if not credentials['server_url'].endswith('/'):
            credentials['server_url'] += '/'

        client_kwargs = {
            "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
            "api_key": "1",
@@ -351,9 +354,20 @@ class LocalAILarguageModel(LargeLanguageModel):
        elif isinstance(message, SystemPromptMessage):
            message = cast(SystemPromptMessage, message)
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ToolPromptMessage):
            # copy from core/model_runtime/model_providers/anthropic/llm/llm.py
            message = cast(ToolPromptMessage, message)
            message_dict = {
                "role": "user",
                "content": [{
                    "type": "tool_result",
                    "tool_use_id": message.tool_call_id,
                    "content": message.content
                }]
            }
        else:
            raise ValueError(f"Unknown message type {type(message)}")

        return message_dict

    def _convert_prompt_message_to_completion_prompts(self, messages: list[PromptMessage]) -> str:
@@ -373,14 +387,14 @@ class LocalAILarguageModel(LargeLanguageModel):
                prompts += f'{message.content}\n'
            else:
                raise ValueError(f"Unknown message type {type(message)}")

        return prompts

    def _handle_completion_generate_response(self, model: str,
                                             prompt_messages: list[PromptMessage],
                                             credentials: dict,
                                             response: Completion,
                                             ) -> LLMResult:
                                             prompt_messages: list[PromptMessage],
                                             credentials: dict,
                                             response: Completion,
                                             ) -> LLMResult:
        """
        Handle llm chat response

@@ -393,7 +407,7 @@ class LocalAILarguageModel(LargeLanguageModel):
        """
        if len(response.choices) == 0:
            raise InvokeServerUnavailableError("Empty response")

        assistant_message = response.choices[0].text

        # transform assistant message to prompt message
@@ -407,7 +421,8 @@ class LocalAILarguageModel(LargeLanguageModel):
        )
        completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=[])

        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
                                          completion_tokens=completion_tokens)

        response = LLMResult(
            model=model,
@@ -436,7 +451,7 @@ class LocalAILarguageModel(LargeLanguageModel):
        """
        if len(response.choices) == 0:
            raise InvokeServerUnavailableError("Empty response")

        assistant_message = response.choices[0].message

        # convert function call to tool call
@@ -452,7 +467,8 @@ class LocalAILarguageModel(LargeLanguageModel):
        prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
        completion_tokens = self._num_tokens_from_messages(messages=[assistant_prompt_message], tools=tools)

        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
        usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens,
                                          completion_tokens=completion_tokens)

        response = LLMResult(
            model=model,
@@ -465,10 +481,10 @@ class LocalAILarguageModel(LargeLanguageModel):
        return response

    def _handle_completion_generate_stream_response(self, model: str,
                                                    prompt_messages: list[PromptMessage],
                                                    credentials: dict,
                                                    response: Stream[Completion],
                                                    tools: list[PromptMessageTool]) -> Generator:
                                                    prompt_messages: list[PromptMessage],
                                                    credentials: dict,
                                                    response: Stream[Completion],
                                                    tools: list[PromptMessageTool]) -> Generator:
        full_response = ''

        for chunk in response:
@@ -496,9 +512,9 @@ class LocalAILarguageModel(LargeLanguageModel):

                completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])

                usage = self._calc_response_usage(model=model, credentials=credentials,
                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                  prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
@@ -538,7 +554,7 @@ class LocalAILarguageModel(LargeLanguageModel):

            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
                continue

            # check if there is a tool call in the response
            function_calls = None
            if delta.delta.function_call:
@@ -562,9 +578,9 @@ class LocalAILarguageModel(LargeLanguageModel):
                prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
                completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])

                usage = self._calc_response_usage(model=model, credentials=credentials,
                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                  prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
@@ -613,7 +629,7 @@ class LocalAILarguageModel(LargeLanguageModel):
            )
            tool_calls.append(tool_call)

        return tool_calls
        return tool_calls

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:

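Because LocalAI serves arbitrary user-supplied models, the `tokens()` helpers above fall back to a GPT-2 tokenizer as a rough, tokenizer-agnostic estimate. A sketch of that approximation using Hugging Face's tokenizer; Dify wraps its own bundled copy behind `_get_num_tokens_by_gpt2`, so the `transformers` usage here is illustrative:

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def approximate_num_tokens(text: str) -> int:
    # GPT-2 BPE counts drift from the served model's real tokenizer,
    # especially on non-English text, but they are stable and cheap,
    # which is all the usage accounting needs.
    return len(tokenizer.encode(text))

print(approximate_num_tokens("Hello, LocalAI!"))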
@@ -4,13 +4,13 @@ from typing import Optional

from pydantic import BaseModel

from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.utils.module_import_helper import load_single_subclass_from_source
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map

logger = logging.getLogger(__name__)


Binary file not shown.
After Width: | Height: | Size: 110 KiB
@@ -0,0 +1,3 @@
<svg width="567" height="376" viewBox="0 0 567 376" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M58.0366 161.868C58.0366 161.868 109.261 86.2912 211.538 78.4724V51.053C98.2528 60.1511 0.152344 156.098 0.152344 156.098C0.152344 156.098 55.7148 316.717 211.538 331.426V302.282C97.1876 287.896 58.0366 161.868 58.0366 161.868ZM211.538 244.32V271.013C125.114 255.603 101.125 165.768 101.125 165.768C101.125 165.768 142.621 119.799 211.538 112.345V141.633C211.486 141.633 211.449 141.617 211.406 141.617C175.235 137.276 146.978 171.067 146.978 171.067C146.978 171.067 162.816 227.949 211.538 244.32ZM211.538 0.47998V51.053C214.864 50.7981 218.189 50.5818 221.533 50.468C350.326 46.1273 434.243 156.098 434.243 156.098C434.243 156.098 337.861 273.296 237.448 273.296C228.245 273.296 219.63 272.443 211.538 271.009V302.282C218.695 303.201 225.903 303.667 233.119 303.675C326.56 303.675 394.134 255.954 459.566 199.474C470.415 208.162 514.828 229.299 523.958 238.55C461.745 290.639 316.752 332.626 234.551 332.626C226.627 332.626 219.018 332.148 211.538 331.426V375.369H566.701V0.47998H211.538ZM211.538 112.345V78.4724C214.829 78.2425 218.146 78.0672 221.533 77.9602C314.148 75.0512 374.909 157.548 374.909 157.548C374.909 157.548 309.281 248.693 238.914 248.693C228.787 248.693 219.707 247.065 211.536 244.318V141.631C247.591 145.987 254.848 161.914 276.524 198.049L324.737 157.398C324.737 157.398 289.544 111.243 230.219 111.243C223.768 111.241 217.597 111.696 211.538 112.345Z" fill="#77B900"/>
</svg>
After Width: | Height: | Size: 1.5 KiB
api/core/model_runtime/model_providers/nvidia_nim/llm/llm.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import logging

from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel

logger = logging.getLogger(__name__)


class NVIDIANIMProvider(OAIAPICompatLargeLanguageModel):
    """
    Model class for NVIDIA NIM large language model.
    """
    pass
@@ -0,0 +1,11 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class NVIDIANIMProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        pass
@@ -0,0 +1,79 @@
provider: nvidia_nim
label:
  en_US: NVIDIA NIM
description:
  en_US: NVIDIA NIM, a set of easy-to-use inference microservices.
  zh_Hans: NVIDIA NIM,一组易于使用的模型推理微服务。
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.png
background: "#EFFDFD"
help:
  title:
    en_US: Learn more about NVIDIA NIM
    zh_Hans: 了解 NVIDIA NIM 更多信息
  url:
    en_US: https://www.nvidia.com/en-us/ai/
supported_model_types:
  - llm
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter full model name
      zh_Hans: 输入模型全称
  credential_form_schemas:
    - variable: endpoint_url
      label:
        zh_Hans: API endpoint URL
        en_US: API endpoint URL
      type: text-input
      required: true
      placeholder:
        zh_Hans: Base URL, e.g. http://192.168.1.100:8000/v1
        en_US: Base URL, e.g. http://192.168.1.100:8000/v1
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Completion mode
      type: select
      required: false
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      type: text-input
      default: '4096'
      placeholder:
        zh_Hans: 在此输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: max_tokens_to_sample
      label:
        zh_Hans: 最大 token 上限
        en_US: Upper bound for max tokens
      show_on:
        - variable: __model_type
          value: llm
      default: '4096'
      type: text-input
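NVIDIA NIM exposes an OpenAI-compatible API, which is why the provider above only needs an `endpoint_url` and can delegate everything else to the generic OpenAI-compatible client. A hedged sketch of a raw request against the example base URL from the config; the host and model name are placeholders:

import requests

resp = requests.post(
    "http://192.168.1.100:8000/v1/chat/completions",  # endpoint_url + standard path
    json={
        "model": "meta/llama3-8b-instruct",  # hypothetical deployed model
        "messages": [{"role": "user", "content": "ping"}],
        "max_tokens": 16,
    },
    timeout=30,
)
print(resp.json()["choices"][0]["message"]["content"])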
@@ -8,7 +8,12 @@ from urllib.parse import urljoin

import requests

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.llm_entities import (
    LLMMode,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
)
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
@@ -40,7 +45,9 @@ from core.model_runtime.errors.invoke import (
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.large_language_model import (
    LargeLanguageModel,
)

logger = logging.getLogger(__name__)

@@ -50,11 +57,17 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
    Model class for Ollama large language model.
    """

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

@@ -75,11 +88,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
            model_parameters=model_parameters,
            stop=stop,
            stream=stream,
            user=user
            user=user,
        )

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> int:
        """
        Get number of tokens for given prompt messages

@@ -100,10 +118,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        if isinstance(first_prompt_message.content, str):
            text = first_prompt_message.content
        else:
            text = ''
            text = ""
            for message_content in first_prompt_message.content:
                if message_content.type == PromptMessageContentType.TEXT:
                    message_content = cast(TextPromptMessageContent, message_content)
                    message_content = cast(
                        TextPromptMessageContent, message_content
                    )
                    text = message_content.data
                    break
        return self._get_num_tokens_by_gpt2(text)
@@ -121,19 +141,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                model=model,
                credentials=credentials,
                prompt_messages=[UserPromptMessage(content="ping")],
                model_parameters={
                    'num_predict': 5
                },
                stream=False
                model_parameters={"num_predict": 5},
                stream=False,
            )
        except InvokeError as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
            raise CredentialsValidateFailedError(
                f"An error occurred during credentials validation: {ex.description}"
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
            raise CredentialsValidateFailedError(
                f"An error occurred during credentials validation: {str(ex)}"
            )

    def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
    def _generate(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        """
        Invoke llm completion model

@@ -146,76 +175,93 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        headers = {
            'Content-Type': 'application/json'
        }
        headers = {"Content-Type": "application/json"}

        endpoint_url = credentials['base_url']
        if not endpoint_url.endswith('/'):
            endpoint_url += '/'
        endpoint_url = credentials["base_url"]
        if not endpoint_url.endswith("/"):
            endpoint_url += "/"

        # prepare the payload for a simple ping to the model
        data = {
            'model': model,
            'stream': stream
        }
        data = {"model": model, "stream": stream}

        if 'format' in model_parameters:
            data['format'] = model_parameters['format']
            del model_parameters['format']
        if "format" in model_parameters:
            data["format"] = model_parameters["format"]
            del model_parameters["format"]

        data['options'] = model_parameters or {}
        if "keep_alive" in model_parameters:
            data["keep_alive"] = model_parameters["keep_alive"]
            del model_parameters["keep_alive"]

        data["options"] = model_parameters or {}

        if stop:
            data['stop'] = "\n".join(stop)
            data["stop"] = "\n".join(stop)

        completion_type = LLMMode.value_of(credentials['mode'])
        completion_type = LLMMode.value_of(credentials["mode"])

        if completion_type is LLMMode.CHAT:
            endpoint_url = urljoin(endpoint_url, 'api/chat')
            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
            endpoint_url = urljoin(endpoint_url, "api/chat")
            data["messages"] = [
                self._convert_prompt_message_to_dict(m) for m in prompt_messages
            ]
        else:
            endpoint_url = urljoin(endpoint_url, 'api/generate')
            endpoint_url = urljoin(endpoint_url, "api/generate")
            first_prompt_message = prompt_messages[0]
            if isinstance(first_prompt_message, UserPromptMessage):
                first_prompt_message = cast(UserPromptMessage, first_prompt_message)
            if isinstance(first_prompt_message.content, str):
                data['prompt'] = first_prompt_message.content
                data["prompt"] = first_prompt_message.content
            else:
                text = ''
                text = ""
                images = []
                for message_content in first_prompt_message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        message_content = cast(
                            TextPromptMessageContent, message_content
                        )
                        text = message_content.data
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
                        message_content = cast(
                            ImagePromptMessageContent, message_content
                        )
                        image_data = re.sub(
                            r"^data:image\/[a-zA-Z]+;base64,",
                            "",
                            message_content.data,
                        )
                        images.append(image_data)

                data['prompt'] = text
                data['images'] = images
                data["prompt"] = text
                data["images"] = images

        # send a post request to validate the credentials
        response = requests.post(
            endpoint_url,
            headers=headers,
            json=data,
            timeout=(10, 300),
            stream=stream
            endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream
        )

        response.encoding = "utf-8"
        if response.status_code != 200:
            raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
            raise InvokeError(
                f"API request failed with status code {response.status_code}: {response.text}"
            )

        if stream:
            return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
            return self._handle_generate_stream_response(
                model, credentials, completion_type, response, prompt_messages
            )

        return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
        return self._handle_generate_response(
            model, credentials, completion_type, response, prompt_messages
        )

    def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode,
                                  response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult:
    def _handle_generate_response(
        self,
        model: str,
        credentials: dict,
        completion_type: LLMMode,
        response: requests.Response,
        prompt_messages: list[PromptMessage],
    ) -> LLMResult:
        """
        Handle llm completion response

@@ -229,14 +275,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        response_json = response.json()

        if completion_type is LLMMode.CHAT:
            message = response_json.get('message', {})
            response_content = message.get('content', '')
            message = response_json.get("message", {})
            response_content = message.get("content", "")
        else:
            response_content = response_json['response']
            response_content = response_json["response"]

        assistant_message = AssistantPromptMessage(content=response_content)

        if 'prompt_eval_count' in response_json and 'eval_count' in response_json:
        if "prompt_eval_count" in response_json and "eval_count" in response_json:
            # transform usage
            prompt_tokens = response_json["prompt_eval_count"]
            completion_tokens = response_json["eval_count"]
@@ -246,7 +292,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
            completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
        usage = self._calc_response_usage(
            model, credentials, prompt_tokens, completion_tokens
        )

        # transform response
        result = LLMResult(
@@ -258,8 +306,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

        return result

    def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode,
                                         response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator:
    def _handle_generate_stream_response(
        self,
        model: str,
        credentials: dict,
        completion_type: LLMMode,
        response: requests.Response,
        prompt_messages: list[PromptMessage],
    ) -> Generator:
        """
        Handle llm completion stream response

@@ -270,17 +324,20 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        :param prompt_messages: prompt messages
        :return: llm response chunk generator result
        """
        full_text = ''
        full_text = ""
        chunk_index = 0

        def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
                -> LLMResultChunk:
        def create_final_llm_result_chunk(
            index: int, message: AssistantPromptMessage, finish_reason: str
        ) -> LLMResultChunk:
            # calculate num tokens
            prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
            completion_tokens = self._get_num_tokens_by_gpt2(full_text)

            # transform usage
            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
            usage = self._calc_response_usage(
                model, credentials, prompt_tokens, completion_tokens
            )

            return LLMResultChunk(
                model=model,
@@ -289,11 +346,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                    index=index,
                    message=message,
                    finish_reason=finish_reason,
                    usage=usage
                )
                    usage=usage,
                ),
            )

        for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'):
        for chunk in response.iter_lines(decode_unicode=True, delimiter="\n"):
            if not chunk:
                continue

@@ -304,7 +361,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                yield create_final_llm_result_chunk(
                    index=chunk_index,
                    message=AssistantPromptMessage(content=""),
                    finish_reason="Non-JSON encountered."
                    finish_reason="Non-JSON encountered.",
                )

                chunk_index += 1
@@ -314,55 +371,57 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                if not chunk_json:
                    continue

                if 'message' not in chunk_json:
                    text = ''
                if "message" not in chunk_json:
                    text = ""
                else:
                    text = chunk_json.get('message').get('content', '')
                    text = chunk_json.get("message").get("content", "")
            else:
                if not chunk_json:
                    continue

                # transform assistant message to prompt message
                text = chunk_json['response']
                text = chunk_json["response"]

            assistant_prompt_message = AssistantPromptMessage(
                content=text
            )
            assistant_prompt_message = AssistantPromptMessage(content=text)

            full_text += text

            if chunk_json['done']:
            if chunk_json["done"]:
                # calculate num tokens
                if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json:
                if "prompt_eval_count" in chunk_json and "eval_count" in chunk_json:
                    # transform usage
                    prompt_tokens = chunk_json["prompt_eval_count"]
                    completion_tokens = chunk_json["eval_count"]
                else:
                    # calculate num tokens
                    prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
                    prompt_tokens = self._get_num_tokens_by_gpt2(
                        prompt_messages[0].content
                    )
                    completion_tokens = self._get_num_tokens_by_gpt2(full_text)

                # transform usage
                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
                usage = self._calc_response_usage(
                    model, credentials, prompt_tokens, completion_tokens
                )

                yield LLMResultChunk(
                    model=chunk_json['model'],
                    model=chunk_json["model"],
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                        finish_reason='stop',
                        usage=usage
                    )
                        finish_reason="stop",
                        usage=usage,
                    ),
                )
            else:
                yield LLMResultChunk(
                    model=chunk_json['model'],
                    model=chunk_json["model"],
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=chunk_index,
                        message=assistant_prompt_message,
                    )
                    ),
                )

            chunk_index += 1
@@ -376,15 +435,21 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
            if isinstance(message.content, str):
                message_dict = {"role": "user", "content": message.content}
            else:
                text = ''
                text = ""
                images = []
                for message_content in message.content:
                    if message_content.type == PromptMessageContentType.TEXT:
                        message_content = cast(TextPromptMessageContent, message_content)
                        message_content = cast(
                            TextPromptMessageContent, message_content
                        )
                        text = message_content.data
                    elif message_content.type == PromptMessageContentType.IMAGE:
                        message_content = cast(ImagePromptMessageContent, message_content)
                        image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
                        message_content = cast(
                            ImagePromptMessageContent, message_content
                        )
                        image_data = re.sub(
                            r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data
                        )
                        images.append(image_data)

                message_dict = {"role": "user", "content": text, "images": images}
@@ -414,7 +479,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

        return num_tokens

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
    def get_customizable_model_schema(
        self, model: str, credentials: dict
    ) -> AIModelEntity:
        """
        Get customizable model schema.

@@ -425,20 +492,19 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
        """
        extras = {}

        if 'vision_support' in credentials and credentials['vision_support'] == 'true':
            extras['features'] = [ModelFeature.VISION]
        if "vision_support" in credentials and credentials["vision_support"] == "true":
            extras["features"] = [ModelFeature.VISION]

        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                zh_Hans=model,
                en_US=model
            ),
            label=I18nObject(zh_Hans=model, en_US=model),
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
                ModelPropertyKey.MODE: credentials.get('mode'),
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
                ModelPropertyKey.MODE: credentials.get("mode"),
                ModelPropertyKey.CONTEXT_SIZE: int(
                    credentials.get("context_size", 4096)
                ),
            },
            parameter_rules=[
                ParameterRule(
@@ -446,152 +512,195 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
                    use_template=DefaultParameterName.TEMPERATURE.value,
                    label=I18nObject(en_US="Temperature"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="The temperature of the model. "
                                          "Increasing the temperature will make the model answer "
                                          "more creatively. (Default: 0.8)"),
                    help=I18nObject(
                        en_US="The temperature of the model. "
                        "Increasing the temperature will make the model answer "
                        "more creatively. (Default: 0.8)"
                    ),
                    default=0.1,
                    min=0,
                    max=1
                    max=1,
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_P.value,
                    use_template=DefaultParameterName.TOP_P.value,
                    label=I18nObject(en_US="Top P"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
                                          "more diverse text, while a lower value (e.g., 0.5) will generate more "
                                          "focused and conservative text. (Default: 0.9)"),
                    help=I18nObject(
                        en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
                        "more diverse text, while a lower value (e.g., 0.5) will generate more "
                        "focused and conservative text. (Default: 0.9)"
                    ),
                    default=0.9,
                    min=0,
                    max=1
                    max=1,
                ),
                ParameterRule(
                    name="top_k",
                    label=I18nObject(en_US="Top K"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Reduces the probability of generating nonsense. "
                                          "A higher value (e.g. 100) will give more diverse answers, "
                                          "while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
                    help=I18nObject(
                        en_US="Reduces the probability of generating nonsense. "
                        "A higher value (e.g. 100) will give more diverse answers, "
                        "while a lower value (e.g. 10) will be more conservative. (Default: 40)"
                    ),
                    min=1,
                    max=100
                    max=100,
                ),
                ParameterRule(
                    name='repeat_penalty',
                    name="repeat_penalty",
                    label=I18nObject(en_US="Repeat Penalty"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
                                          "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
                                          "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
                    help=I18nObject(
                        en_US="Sets how strongly to penalize repetitions. "
                        "A higher value (e.g., 1.5) will penalize repetitions more strongly, "
                        "while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
                    ),
                    min=-2,
                    max=2
                    max=2,
                ),
                ParameterRule(
                    name='num_predict',
                    use_template='max_tokens',
                    name="num_predict",
                    use_template="max_tokens",
                    label=I18nObject(en_US="Num Predict"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
                                          "(Default: 128, -1 = infinite generation, -2 = fill context)"),
                    default=512 if int(credentials.get('max_tokens', 4096)) >= 768 else 128,
                    help=I18nObject(
                        en_US="Maximum number of tokens to predict when generating text. "
                        "(Default: 128, -1 = infinite generation, -2 = fill context)"
                    ),
                    default=(
                        512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128
                    ),
                    min=-2,
                    max=int(credentials.get('max_tokens', 4096)),
                    max=int(credentials.get("max_tokens", 4096)),
                ),
                ParameterRule(
                    name='mirostat',
                    name="mirostat",
                    label=I18nObject(en_US="Mirostat sampling"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
                                          "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
                    help=I18nObject(
                        en_US="Enable Mirostat sampling for controlling perplexity. "
                        "(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
                    ),
                    min=0,
                    max=2
                    max=2,
                ),
                ParameterRule(
                    name='mirostat_eta',
                    name="mirostat_eta",
                    label=I18nObject(en_US="Mirostat Eta"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from "
                                          "the generated text. A lower learning rate will result in slower adjustments, "
                                          "while a higher learning rate will make the algorithm more responsive. "
                                          "(Default: 0.1)"),
                    precision=1
                    help=I18nObject(
                        en_US="Influences how quickly the algorithm responds to feedback from "
                        "the generated text. A lower learning rate will result in slower adjustments, "
                        "while a higher learning rate will make the algorithm more responsive. "
                        "(Default: 0.1)"
                    ),
                    precision=1,
                ),
                ParameterRule(
                    name='mirostat_tau',
                    name="mirostat_tau",
                    label=I18nObject(en_US="Mirostat Tau"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
                                          "A lower value will result in more focused and coherent text. (Default: 5.0)"),
                    precision=1
                    help=I18nObject(
                        en_US="Controls the balance between coherence and diversity of the output. "
                        "A lower value will result in more focused and coherent text. (Default: 5.0)"
                    ),
                    precision=1,
                ),
                ParameterRule(
                    name='num_ctx',
                    name="num_ctx",
                    label=I18nObject(en_US="Size of context window"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets the size of the context window used to generate the next token. "
                                          "(Default: 2048)"),
                    help=I18nObject(
                        en_US="Sets the size of the context window used to generate the next token. "
                        "(Default: 2048)"
                    ),
                    default=2048,
                    min=1
                    min=1,
                ),
                ParameterRule(
                    name='num_gpu',
                    label=I18nObject(en_US="Num GPU"),
                    label=I18nObject(en_US="GPU Layers"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="The number of layers to send to the GPU(s). "
                                          "On macOS it defaults to 1 to enable metal support, 0 to disable."),
                    min=0,
                    max=1
                    help=I18nObject(en_US="The number of layers to offload to the GPU(s). "
                                          "On macOS it defaults to 1 to enable metal support, 0 to disable."
                                          "As long as a model fits into one gpu it stays in one. "
                                          "It does not set the number of GPU(s). "),
                    min=-1,
                    default=1
                ),
                ParameterRule(
                    name='num_thread',
                    name="num_thread",
                    label=I18nObject(en_US="Num Thread"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets the number of threads to use during computation. "
                                          "By default, Ollama will detect this for optimal performance. "
                                          "It is recommended to set this value to the number of physical CPU cores "
                                          "your system has (as opposed to the logical number of cores)."),
                    help=I18nObject(
                        en_US="Sets the number of threads to use during computation. "
                        "By default, Ollama will detect this for optimal performance. "
                        "It is recommended to set this value to the number of physical CPU cores "
                        "your system has (as opposed to the logical number of cores)."
                    ),
                    min=1,
                ),
                ParameterRule(
                    name='repeat_last_n',
                    name="repeat_last_n",
                    label=I18nObject(en_US="Repeat last N"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
                                          "(Default: 64, 0 = disabled, -1 = num_ctx)"),
                    min=-1
                    help=I18nObject(
                        en_US="Sets how far back for the model to look back to prevent repetition. "
                        "(Default: 64, 0 = disabled, -1 = num_ctx)"
                    ),
                    min=-1,
                ),
                ParameterRule(
                    name='tfs_z',
                    name="tfs_z",
                    label=I18nObject(en_US="TFS Z"),
                    type=ParameterType.FLOAT,
                    help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
                                          "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
                                          "while a value of 1.0 disables this setting. (default: 1)"),
                    precision=1
                    help=I18nObject(
                        en_US="Tail free sampling is used to reduce the impact of less probable tokens "
                        "from the output. A higher value (e.g., 2.0) will reduce the impact more, "
                        "while a value of 1.0 disables this setting. (default: 1)"
                    ),
                    precision=1,
                ),
                ParameterRule(
                    name='seed',
                    name="seed",
                    label=I18nObject(en_US="Seed"),
                    type=ParameterType.INT,
                    help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
                                          "a specific number will make the model generate the same text for "
                                          "the same prompt. (Default: 0)"),
                    help=I18nObject(
                        en_US="Sets the random number seed to use for generation. Setting this to "
                        "a specific number will make the model generate the same text for "
                        "the same prompt. (Default: 0)"
                    ),
                ),
                ParameterRule(
                    name='format',
                    name="keep_alive",
                    label=I18nObject(en_US="Keep Alive"),
                    type=ParameterType.STRING,
                    help=I18nObject(
                        en_US="Sets how long the model is kept in memory after generating a response. "
                        "This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours). "
                        "A negative number keeps the model loaded indefinitely, and '0' unloads the model immediately after generating a response. "
                        "Valid time units are 's','m','h'. (Default: 5m)"
                    ),
                ),
                ParameterRule(
                    name="format",
                    label=I18nObject(en_US="Format"),
                    type=ParameterType.STRING,
                    help=I18nObject(en_US="the format to return a response in."
                                          " Currently the only accepted value is json."),
                    options=['json'],
                )
                    help=I18nObject(
                        en_US="the format to return a response in."
                        " Currently the only accepted value is json."
                    ),
                    options=["json"],
                ),
            ],
            pricing=PriceConfig(
                input=Decimal(credentials.get('input_price', 0)),
                output=Decimal(credentials.get('output_price', 0)),
                unit=Decimal(credentials.get('unit', 0)),
                currency=credentials.get('currency', "USD")
                input=Decimal(credentials.get("input_price", 0)),
                output=Decimal(credentials.get("output_price", 0)),
                unit=Decimal(credentials.get("unit", 0)),
                currency=credentials.get("currency", "USD"),
            ),
            **extras
            **extras,
        )

        return entity
@@ -619,10 +728,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
            ],
            InvokeServerUnavailableError: [
                requests.exceptions.ConnectionError,  # Engine Overloaded
                requests.exceptions.HTTPError  # Server Error
                requests.exceptions.HTTPError,  # Server Error
            ],
            InvokeConnectionError: [
                requests.exceptions.ConnectTimeout,  # Timeout
                requests.exceptions.ReadTimeout  # Timeout
            ]
                requests.exceptions.ReadTimeout,  # Timeout
            ],
        }

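The rewritten `_generate` makes the wire format explicit: chat requests go to `api/chat` with a `messages` array, completion requests to `api/generate` with `prompt`/`images`, and `format`, `keep_alive`, and `stop` are lifted out of `model_parameters` while everything else lands under `options`. A minimal reproduction against a local Ollama server; the host and model tag are assumptions:

import requests

payload = {
    "model": "llama3",                 # placeholder model tag
    "stream": False,
    "keep_alive": "10m",               # new: forwarded from model_parameters
    "options": {"num_predict": 5, "temperature": 0.8},
    "messages": [{"role": "user", "content": "ping"}],
}
resp = requests.post("http://localhost:11434/api/chat", json=payload, timeout=(10, 300))
print(resp.json()["message"]["content"])
# When present, prompt_eval_count / eval_count carry exact token usage;
# otherwise the handler above falls back to a GPT-2 estimate.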
@@ -9,7 +9,7 @@ features:
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 4096
  context_size: 16385
parameter_rules:
  - name: temperature
    use_template: temperature

File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 5.6 KiB
@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" fill="none" version="1.1" width="128" height="128" viewBox="0 0 128 128"><g><g style="opacity:0;"><rect x="0" y="0" width="128" height="128" rx="0" fill="#FFFFFF" fill-opacity="1"/></g><g><path d="M100.74,12L93.2335,12C69.21260000000001,12,55.3672,27.3468,55.3672,50.8672L55.3672,54.8988C52.6011,54.1056,49.7377,53.7031,46.8601,53.7031C29.816499999999998,53.7031,16,67.5196,16,84.5632C16,101.6069,29.816499999999998,115.423,46.8601,115.423C63.9037,115.423,77.72030000000001,101.6069,77.72030000000001,84.5632C77.72030000000001,82.4902,77.51140000000001,80.4223,77.0967,78.3911L77.2197,78.3911L100.74,78.3911C106.9654,78.3681,112,73.3151,112,67.08959999999999C112,60.8642,106.9654,55.8111,100.74,55.7882L100.7362,55.7882L100.6985,55.7879L100.6606,55.7882L77.2197,55.7882L77.2195,49.8663C77.2195,40.8584,83.7252,34.352900000000005,93.2335,34.352900000000005L100.5653,34.352900000000005L100.5733,34.352900000000005L100.5812,34.352900000000005L100.74,34.352900000000005L100.74,34.352900000000005C106.8469,34.2605,111.7497,29.284,111.7497,23.1764C111.7497,17.06889,106.8469,12.0923454,100.74,12L100.74,12ZM56.0347,84.5632C56.0347,79.4962,51.9271,75.3885,46.8601,75.3885C41.793099999999995,75.3885,37.6854,79.4962,37.6854,84.5632C37.6854,89.6303,41.793099999999995,93.7378,46.8601,93.7378C51.9271,93.7378,56.0347,89.6303,56.0347,84.5632Z" fill-rule="evenodd" fill="#8358F6" fill-opacity="1"/></g></g></svg>
After Width: | Height: | Size: 1.4 KiB
@@ -0,0 +1,8 @@
- deepseek-v2-chat
- qwen2-72b-instruct
- qwen2-57b-a14b-instruct
- qwen2-7b-instruct
- yi-1.5-34b-chat
- yi-1.5-9b-chat
- yi-1.5-6b-chat
- glm4-9B-chat
@@ -0,0 +1,32 @@
model: deepseek-ai/deepseek-v2-chat
label:
  en_US: deepseek-ai/deepseek-v2-chat
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 512
    min: 1
    max: 4096
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
  - name: frequency_penalty
    use_template: frequency_penalty
pricing:
  input: '1.33'
  output: '1.33'
  unit: '0.000001'
  currency: RMB
@@ -0,0 +1,32 @@
model: zhipuai/glm4-9B-chat
label:
  en_US: zhipuai/glm4-9B-chat
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 512
    min: 1
    max: 4096
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
  - name: frequency_penalty
    use_template: frequency_penalty
pricing:
  input: '0.6'
  output: '0.6'
  unit: '0.000001'
  currency: RMB
@@ -0,0 +1,25 @@
from collections.abc import Generator
from typing import Optional, Union

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @classmethod
    def _add_custom_parameters(cls, credentials: dict) -> None:
        credentials['mode'] = 'chat'
        credentials['endpoint_url'] = 'https://api.siliconflow.cn/v1'
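In effect, callers only supply an API key; `_add_custom_parameters` pins the endpoint and chat mode before delegating to the OpenAI-compatible base class. A quick illustration of what that does to the credentials dict (the key value is a placeholder):

credentials = {"api_key": "sk-..."}  # placeholder key
SiliconflowLargeLanguageModel._add_custom_parameters(credentials)
print(credentials)
# {'api_key': 'sk-...', 'mode': 'chat',
#  'endpoint_url': 'https://api.siliconflow.cn/v1'}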
@ -0,0 +1,32 @@
|
||||
model: alibaba/Qwen2-57B-A14B-Instruct
|
||||
label:
|
||||
en_US: alibaba/Qwen2-57B-A14B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
pricing:
|
||||
input: '1.26'
|
||||
output: '1.26'
|
||||
unit: '0.000001'
|
||||
currency: RMB
|
||||
@ -0,0 +1,32 @@
|
||||
model: alibaba/Qwen2-72B-Instruct
|
||||
label:
|
||||
en_US: alibaba/Qwen2-72B-Instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32768
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
pricing:
|
||||
input: '4.13'
|
||||
output: '4.13'
|
||||
unit: '0.000001'
|
||||
currency: RMB
@@ -0,0 +1,32 @@
model: alibaba/Qwen2-7B-Instruct
label:
  en_US: alibaba/Qwen2-7B-Instruct
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 32768
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 512
    min: 1
    max: 4096
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
  - name: frequency_penalty
    use_template: frequency_penalty
pricing:
  input: '0.35'
  output: '0.35'
  unit: '0.000001'
  currency: RMB
@@ -0,0 +1,32 @@
model: 01-ai/Yi-1.5-34B-Chat
label:
  en_US: 01-ai/Yi-1.5-34B-Chat
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 16384
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 512
    min: 1
    max: 4096
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
  - name: frequency_penalty
    use_template: frequency_penalty
pricing:
  input: '1.26'
  output: '1.26'
  unit: '0.000001'
  currency: RMB
@@ -0,0 +1,32 @@
model: 01-ai/Yi-1.5-6B-Chat
label:
  en_US: 01-ai/Yi-1.5-6B-Chat
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 512
    min: 1
    max: 4096
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
  - name: frequency_penalty
    use_template: frequency_penalty
pricing:
  input: '0.35'
  output: '0.35'
  unit: '0.000001'
  currency: RMB
@@ -0,0 +1,32 @@
model: 01-ai/Yi-1.5-9B-Chat
label:
  en_US: 01-ai/Yi-1.5-9B-Chat
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 16384
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: max_tokens
    use_template: max_tokens
    type: int
    default: 512
    min: 1
    max: 4096
    help:
      zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: top_p
    use_template: top_p
  - name: frequency_penalty
    use_template: frequency_penalty
pricing:
  input: '0.42'
  output: '0.42'
  unit: '0.000001'
  currency: RMB
@@ -0,0 +1,29 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class SiliconflowProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.
        If validation fails, raise an exception.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # Validate by invoking one known chat model; an invalid key fails here.
            model_instance.validate_credentials(
                model='deepseek-ai/deepseek-v2-chat',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
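
Note: provider-level validation above simply reuses model-level validation against one known model and maps every failure to a single error type. For illustration only, the same pattern outside Dify could look like the sketch below, assuming the endpoint exposes the standard OpenAI-compatible /models route; the function and error names are illustrative, not from this diff:

import requests

class CredentialsValidateFailedError(Exception):
    pass

def validate_api_key(api_key: str) -> None:
    # Cheapest authenticated call: list models, and map auth failures
    # to one caller-friendly error type.
    resp = requests.get(
        'https://api.siliconflow.cn/v1/models',
        headers={'Authorization': f'Bearer {api_key}'},
        timeout=10,
    )
    if resp.status_code in (401, 403):
        raise CredentialsValidateFailedError('invalid SiliconFlow API key')
    resp.raise_for_status()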
@@ -0,0 +1,29 @@
provider: siliconflow
label:
  zh_Hans: 硅基流动
  en_US: SiliconFlow
icon_small:
  en_US: siliconflow_square.svg
icon_large:
  en_US: siliconflow.svg
background: "#ffecff"
help:
  title:
    en_US: Get your API Key from SiliconFlow
    zh_Hans: 从 SiliconFlow 获取 API Key
  url:
    en_US: https://cloud.siliconflow.cn/keys
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
@@ -0,0 +1,56 @@
model: claude-3-haiku@20240307
label:
  en_US: Claude 3 Haiku
model_type: llm
features:
  - agent-thought
  - vision
model_properties:
  mode: chat
  context_size: 200000
parameter_rules:
  - name: max_tokens
    use_template: max_tokens
    required: true
    type: int
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
      en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
  # docs: https://docs.anthropic.com/claude/docs/system-prompts
  - name: temperature
    use_template: temperature
    required: false
    type: float
    default: 1
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 生成内容的随机性。
      en_US: The amount of randomness injected into the response.
  - name: top_p
    required: false
    type: float
    default: 0.999
    min: 0.000
    max: 1.000
    help:
      zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
      en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
  - name: top_k
    required: false
    type: int
    default: 0
    min: 0
    # note: the AWS tip docs are wrong here; the actual max value is 500
    max: 500
    help:
      zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
      en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
pricing:
  input: '0.00025'
  output: '0.00125'
  unit: '0.001'
  currency: USD
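
Note: here unit '0.001' means the prices are quoted per 1,000 tokens, so 1,000,000 × 0.00025 × 0.001 = 0.25 USD per million input tokens, and 1.25 USD per million output tokens; the arithmetic is the same as in the RMB blocks above, only the unit differs.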
Some files were not shown because too many files have changed in this diff.