Mirror of https://github.com/langgenius/dify.git, synced 2026-02-01 01:17:00 +08:00

Compare commits: 85 commits (feat/class ... 0.15.8)
| SHA1 |
|---|
| 4b938ab18d |
| 88356de923 |
| 5f09900dca |
| 9ac99abf20 |
| 32588f562e |
| 36f8bd3f1a |
| c919074e06 |
| 88cd9aedb7 |
| 16a4f77fb4 |
| 3401c52665 |
| 4fa3d78ed8 |
| 5f7f851b17 |
| 559ab46ee1 |
| df98223c8c |
| 144f9507f8 |
| 2e097a1ac0 |
| 9f7d8a981f |
| 40b31bafd5 |
| d38a2c95fb |
| 7d18e2a0ef |
| 024f242251 |
| bfdce78ca5 |
| 00c2258352 |
| a1b3d41712 |
| b26e20fe34 |
| 161ff432f1 |
| 99a9def623 |
| fe1846c437 |
| 8e75eb5c63 |
| 970508fcb6 |
| 9283a5414f |
| 2a2a0e9be9 |
| 061a765b7d |
| acd7fead87 |
| bbb080d5b2 |
| c01d8a70f3 |
| 1ca15989e0 |
| 8b5a3a9424 |
| 42ddcf1edd |
| 21561df10f |
| 0e33a3aa5f |
| d3895bcd6b |
| eeb390650b |
| ca19bd31d4 |
| 413dfd5628 |
| f9515901cc |
| 3f42fabff8 |
| 1caa578771 |
| b7c11c1818 |
| 3eb3db0663 |
| be46f32056 |
| 6e5c915f96 |
| 04d13a8116 |
| e638ede3f2 |
| 2348abe4bf |
| f7e7a399d9 |
| ba91f34636 |
| 16865d43a8 |
| 0d13aee15c |
| 49b4144ffd |
| 186e2d972e |
| 40dd63ecef |
| 6d66d6da15 |
| 03ec3513f3 |
| 87763fc234 |
| f6c44cae2e |
| da2ee04fce |
| 7673c36af3 |
| 9457b2af2f |
| 7203991032 |
| 5a685f7156 |
| a6a25030ad |
| 00458a31d5 |
| c6ddf6d6cc |
| 34b21b3065 |
| 8fbb355cd2 |
| e8b3b7e578 |
| 59ca44f493 |
| 9e1457c2c3 |
| fac83e14bc |
| a97cec57e4 |
| 38c10b47d3 |
| 1a2523fd15 |
| 03243cb422 |
| 2ad7ee0344 |
.github/workflows/build-push.yml (4 changes, vendored)
@@ -5,8 +5,8 @@ on:
branches:
- "main"
- "deploy/dev"
release:
types: [published]
tags:
- "*"

concurrency:
group: build-push-${{ github.head_ref || github.run_id }}
.github/workflows/docker-build.yml (47 changes, vendored, new file)
@@ -0,0 +1,47 @@
name: Build docker image

on:
  pull_request:
    branches:
      - "main"
    paths:
      - api/Dockerfile
      - web/Dockerfile

concurrency:
  group: docker-build-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build-docker:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        include:
          - service_name: "api-amd64"
            platform: linux/amd64
            context: "api"
          - service_name: "api-arm64"
            platform: linux/arm64
            context: "api"
          - service_name: "web-amd64"
            platform: linux/amd64
            context: "web"
          - service_name: "web-arm64"
            platform: linux/arm64
            context: "web"
    steps:
      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Build Docker Image
        uses: docker/build-push-action@v6
        with:
          push: false
          context: "{{defaultContext}}:${{ matrix.context }}"
          platforms: ${{ matrix.platform }}
          cache-from: type=gha
          cache-to: type=gha,mode=max
.markdownlint.json (4 changes, new file)
@@ -0,0 +1,4 @@
{
  "MD024": false,
  "MD013": false
}
CHANGELOG.md (45 changes, new file)
@@ -0,0 +1,45 @@
# Changelog

All notable changes to Dify will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.15.8] - 2025-05-30

### Added

- Added gunicorn keepalive setting (#19537)

### Fixed

- Fixed database configuration to allow DB_EXTRAS to set search_path via options (#16a4f77)
- Fixed frontend third-party package security issues (#19655)
- Updated dependencies: huggingface-hub (~0.16.4 to ~0.31.0), transformers (~4.35.0 to ~4.39.0), and resend (~0.7.0 to ~2.9.0) (#19563)
- Downgrade boto3 from 1.36 to 1.35 (#19736)

## [0.15.7] - 2025-04-27

### Added

- Added support for GPT-4.1 in model providers (#18912)
- Added support for Amazon Bedrock DeepSeek-R1 model (#18908)
- Added support for Amazon Bedrock Claude Sonnet 3.7 model (#18788)
- Refined version compatibility logic in app DSL service

### Fixed

- Fixed issue with creating apps from template categories (#18807, #18868)
- Fixed DSL version check when creating apps from explore templates (#18872, #18878)

## [0.15.6] - 2025-04-22

### Security

- Fixed clickjacking vulnerability (#18552)
- Fixed reset password security issue (#18366)
- Updated reset password token when email code verification succeeds (#18362)

### Fixed

- Fixed Vertex AI Gemini 2.0 Flash 001 schema (#18405)
@@ -25,6 +25,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="follow on X(Twitter)"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="follow on LinkedIn"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">

@@ -21,6 +21,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="follow on X(Twitter)"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="follow on LinkedIn"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">

@@ -21,6 +21,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="follow on X(Twitter)"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="follow on LinkedIn"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">

@@ -21,6 +21,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="seguir en X(Twitter)"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="seguir en LinkedIn"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Descargas de Docker" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">

@@ -21,6 +21,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="suivre sur X(Twitter)"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="suivre sur LinkedIn"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Tirages Docker" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">

@@ -21,6 +21,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="X(Twitter)でフォロー"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="LinkedInでフォロー"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -21,6 +21,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="follow on X(Twitter)"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="follow on LinkedIn"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">

@@ -21,6 +21,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="follow on X(Twitter)"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="follow on LinkedIn"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">

@@ -25,6 +25,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="follow on X(Twitter)"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="follow on LinkedIn"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">

@@ -22,6 +22,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="follow on X(Twitter)"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="follow on LinkedIn"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -21,6 +21,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="X(Twitter)'da takip et"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="LinkedIn'da takip et"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Çekmeleri" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">

@@ -62,8 +65,6 @@ Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edi



Özür dilerim, haklısınız. Daha anlamlı ve akıcı bir çeviri yapmaya çalışayım. İşte güncellenmiş çeviri:

**3. Prompt IDE**:
Komut istemlerini oluşturmak, model performansını karşılaştırmak ve sohbet tabanlı uygulamalara metin-konuşma gibi ek özellikler eklemek için kullanıcı dostu bir arayüz.

@@ -150,8 +151,6 @@ Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edi
## Dify'ı Kullanma

- **Cloud </br>**
İşte verdiğiniz metnin Türkçe çevirisi, kod bloğu içinde:
-
Herkesin sıfır kurulumla denemesi için bir [Dify Cloud](https://dify.ai) hizmeti sunuyoruz. Bu hizmet, kendi kendine dağıtılan versiyonun tüm yeteneklerini sağlar ve sandbox planında 200 ücretsiz GPT-4 çağrısı içerir.

- **Dify Topluluk Sürümünü Kendi Sunucunuzda Barındırma</br>**

@@ -177,8 +176,6 @@ GitHub'da Dify'a yıldız verin ve yeni sürümlerden anında haberdar olun.
>- RAM >= 4GB

</br>
İşte verdiğiniz metnin Türkçe çevirisi, kod bloğu içinde:

Dify sunucusunu başlatmanın en kolay yolu, [docker-compose.yml](docker/docker-compose.yaml) dosyamızı çalıştırmaktır. Kurulum komutunu çalıştırmadan önce, makinenizde [Docker](https://docs.docker.com/get-docker/) ve [Docker Compose](https://docs.docker.com/compose/install/)'un kurulu olduğundan emin olun:

```bash
@@ -21,6 +21,9 @@
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
alt="theo dõi trên X(Twitter)"></a>
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
alt="theo dõi trên LinkedIn"></a>
<a href="https://hub.docker.com/u/langgenius" target="_blank">
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -430,4 +430,7 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400

# Prevent Clickjacking
ALLOW_EMBED=false
@@ -48,18 +48,18 @@ ENV TZ=UTC

WORKDIR /app/api

RUN apt-get update \
&& apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# if you located in China, you can use aliyun mirror to speed up
# && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \
&& echo "deb http://deb.debian.org/debian bookworm main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install a chinese font to support the use of tools like matplotlib
&& apt-get install -y fonts-noto-cjk \
# install libmagic to support the use of python-magic guess MIMETYPE
&& apt-get install -y libmagic1 \
RUN \
apt-get update \
# Install dependencies
&& apt-get install -y --no-install-recommends \
# basic environment
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install a chinese font to support the use of tools like matplotlib
fonts-noto-cjk \
# install libmagic to support the use of python-magic guess MIMETYPE
libmagic1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*

@@ -78,7 +78,6 @@ COPY . /app/api/
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA}
@@ -1,9 +1,40 @@
from typing import Optional

from pydantic import Field, NonNegativeInt
from pydantic import Field, NonNegativeInt, computed_field
from pydantic_settings import BaseSettings


class HostedCreditConfig(BaseSettings):
HOSTED_MODEL_CREDIT_CONFIG: str = Field(
description="Model credit configuration in format 'model:credits,model:credits', e.g., 'gpt-4:20,gpt-4o:10'",
default="",
)

def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
Returns 1 if model is not found in configuration (default credit).

:param model_name: The name of the model to search for
:return: The credit value for the model
"""
if not self.HOSTED_MODEL_CREDIT_CONFIG:
return 1

try:
credit_map = dict(
item.strip().split(":", 1) for item in self.HOSTED_MODEL_CREDIT_CONFIG.split(",") if ":" in item
)

# Search for matching model pattern
for pattern, credit in credit_map.items():
if pattern.strip() == model_name:
return int(credit)
return 1  # Default quota if no match found
except (ValueError, AttributeError):
return 1  # Return default quota if parsing fails


class HostedOpenAiConfig(BaseSettings):
"""
Configuration for hosted OpenAI service

@@ -202,5 +233,7 @@ class HostedServiceConfig(
HostedZhipuAIConfig,
# moderation
HostedModerationConfig,
# credit config
HostedCreditConfig,
):
pass
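For reference, a minimal standalone sketch of the credit lookup introduced by the HostedCreditConfig hunk above; it mirrors the parsing shown in get_model_credits, and the config string and model names below are illustrative only.

```python
def get_model_credits(config: str, model_name: str) -> int:
    """Return the per-call credit for model_name; 1 is the default credit."""
    if not config:
        return 1
    try:
        # "gpt-4:20,gpt-4o:10" -> {"gpt-4": "20", "gpt-4o": "10"}
        credit_map = dict(
            item.strip().split(":", 1) for item in config.split(",") if ":" in item
        )
        for pattern, credit in credit_map.items():
            if pattern.strip() == model_name:
                return int(credit)
        return 1  # no matching model -> default credit
    except (ValueError, AttributeError):
        return 1  # malformed config -> default credit


print(get_model_credits("gpt-4:20,gpt-4o:10", "gpt-4"))    # 20
print(get_model_credits("gpt-4:20,gpt-4o:10", "gpt-3.5"))  # 1
```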
@@ -1,5 +1,5 @@
from typing import Any, Literal, Optional
from urllib.parse import quote_plus
from urllib.parse import parse_qsl, quote_plus

from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings

@@ -166,14 +166,28 @@ class DatabaseConfig(BaseSettings):
default=False,
)

@computed_field
@computed_field  # type: ignore[misc]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
# Always include timezone
timezone_opt = "-c timezone=UTC"
if options:
# Merge user options and timezone
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt

connect_args = {"options": merged_options}

return {
"pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": {"options": "-c timezone=UTC"},
"connect_args": connect_args,
}
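A small sketch of the DB_EXTRAS merge that the SQLALCHEMY_ENGINE_OPTIONS change above performs; the search_path schema name is a hypothetical example, not a value from the repository.

```python
from urllib.parse import parse_qsl

def engine_connect_args(db_extras: str) -> dict:
    """Merge the user-supplied 'options' from DB_EXTRAS with the fixed UTC timezone option."""
    options = dict(parse_qsl(db_extras)).get("options", "")
    timezone_opt = "-c timezone=UTC"
    merged = f"{options} {timezone_opt}" if options else timezone_opt
    return {"options": merged}

# Hypothetical schema name, e.g. set via DB_EXTRAS="options=-c search_path=myschema"
print(engine_connect_args("options=-c search_path=myschema"))
# {'options': '-c search_path=myschema -c timezone=UTC'}
```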
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

CURRENT_VERSION: str = Field(
description="Dify version",
default="0.15.2",
default="0.15.8",
)

COMMIT_SHA: str = Field(
@@ -8,7 +8,7 @@ from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import setup_required
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip

@@ -22,6 +22,7 @@ from services.feature_service import FeatureService

class ForgotPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")

@@ -53,6 +54,7 @@ class ForgotPasswordSendEmailApi(Resource):

class ForgotPasswordCheckApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")

@@ -72,11 +74,20 @@ class ForgotPasswordCheckApi(Resource):
if args["code"] != token_data.get("code"):
raise EmailCodeError()

return {"is_valid": True, "email": token_data.get("email")}
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])

# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
)

return {"is_valid": True, "email": token_data.get("email"), "token": new_token}


class ForgotPasswordResetApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")

@@ -95,6 +106,9 @@ class ForgotPasswordResetApi(Resource):

if reset_data is None:
raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()

AccountService.revoke_reset_password_token(token)
@@ -22,7 +22,7 @@ from controllers.console.error import (
EmailSendIpLimitError,
NotAllowedCreateWorkspace,
)
from controllers.console.wraps import setup_required
from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.password import valid_password

@@ -38,6 +38,7 @@ class LoginApi(Resource):
"""Resource for user login."""

@setup_required
@email_password_login_enabled
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()

@@ -110,6 +111,7 @@ class LogoutApi(Resource):

class ResetPasswordSendEmailApi(Resource):
@setup_required
@email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
@@ -154,3 +154,16 @@ def enterprise_license_required(view):
return view(*args, **kwargs)

return decorated


def email_password_login_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if features.enable_email_password_login:
return view(*args, **kwargs)

# otherwise, return 403
abort(403)

return decorated
@@ -104,7 +104,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):

# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages,

@@ -84,7 +84,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):

# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
@@ -55,20 +55,6 @@ class AgentChatAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files

# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)

memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)
@@ -15,10 +15,8 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform

@@ -31,106 +29,6 @@ if TYPE_CHECKING:


class AppRunner:
def get_pre_calculate_rest_tokens(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: Mapping[str, str],
files: Sequence["File"],
query: Optional[str] = None,
) -> int:
"""
Get pre calculate rest tokens
:param app_record: app record
:param model_config: model config entity
:param prompt_template_entity: prompt template entity
:param inputs: inputs
:param files: files
:param query: query
:return:
"""
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)

model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0

if model_context_tokens is None:
return -1

if max_tokens is None:
max_tokens = 0

# get prompt messages without memory and context
prompt_messages, stop = self.organize_prompt_messages(
app_record=app_record,
model_config=model_config,
prompt_template_entity=prompt_template_entity,
inputs=inputs,
files=files,
query=query,
)

prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size."
)

return rest_tokens

def recalc_llm_max_tokens(
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)

model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
) or 0

if model_context_tokens is None:
return -1

if max_tokens is None:
max_tokens = 0

prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16)

for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
model_config.parameters[parameter_rule.name] = max_tokens

def organize_prompt_messages(
self,
app_record: App,
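To make the clamp in recalc_llm_max_tokens concrete, here is a worked example with illustrative numbers (they are not taken from any real model configuration):

```python
# Worked example of the max_tokens clamp shown above, with illustrative numbers.
model_context_tokens = 8192   # model context window
prompt_tokens = 8000          # tokens already consumed by the prompt
max_tokens = 1024             # requested completion budget

if prompt_tokens + max_tokens > model_context_tokens:
    # shrink the completion budget so prompt + completion fits,
    # but never go below 16 tokens
    max_tokens = max(model_context_tokens - prompt_tokens, 16)

print(max_tokens)  # 192
```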
@@ -50,20 +50,6 @@ class ChatAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files

# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)

memory = None
if application_generate_entity.conversation_id:
# get memory of conversation (read-only)

@@ -194,9 +180,6 @@ class ChatAppRunner(AppRunner):
if hosting_moderation_result:
return

# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)

# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
@@ -43,20 +43,6 @@ class CompletionAppRunner(AppRunner):
query = application_generate_entity.query
files = application_generate_entity.files

# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)

# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
prompt_messages, stop = self.organize_prompt_messages(

@@ -152,9 +138,6 @@ class CompletionAppRunner(AppRunner):
if hosting_moderation_result:
return

# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)

# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
@@ -11,15 +11,6 @@ from configs import dify_config

SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES

proxy_mounts = (
{
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL
else None
)

BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]

@@ -51,7 +42,11 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client:
response = client.request(method=method, url=url, **kwargs)
elif proxy_mounts:
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
}
with httpx.Client(mounts=proxy_mounts) as client:
response = client.request(method=method, url=url, **kwargs)
else:
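For context, a hedged sketch of the per-request proxy selection that make_request now performs instead of building proxy_mounts at import time; the proxy URLs are placeholders standing in for the SSRF_PROXY_* settings.

```python
import httpx

SSRF_PROXY_ALL_URL = ""                          # e.g. "http://ssrf-proxy:3128" (placeholder)
SSRF_PROXY_HTTP_URL = "http://ssrf-proxy:3128"   # placeholder
SSRF_PROXY_HTTPS_URL = "http://ssrf-proxy:3128"  # placeholder

def make_request(method: str, url: str, **kwargs) -> httpx.Response:
    if SSRF_PROXY_ALL_URL:
        # a single proxy for every scheme
        with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client:
            return client.request(method=method, url=url, **kwargs)
    elif SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL:
        # build the transports lazily, per request, rather than at import time
        proxy_mounts = {
            "http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL),
            "https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL),
        }
        with httpx.Client(mounts=proxy_mounts) as client:
            return client.request(method=method, url=url, **kwargs)
    else:
        with httpx.Client() as client:
            return client.request(method=method, url=url, **kwargs)
```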
@@ -26,7 +26,7 @@ class TokenBufferMemory:
self.model_instance = model_instance

def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
self, max_token_limit: int = 100000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
@@ -1,4 +1,4 @@
from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .message_entities import (
AssistantPromptMessage,
AudioPromptMessageContent,

@@ -23,6 +23,7 @@ __all__ = [
"AudioPromptMessageContent",
"DocumentPromptMessageContent",
"ImagePromptMessageContent",
"LLMMode",
"LLMResult",
"LLMResultChunk",
"LLMResultChunkDelta",
@@ -1,5 +1,5 @@
from decimal import Decimal
from enum import Enum
from enum import StrEnum
from typing import Optional

from pydantic import BaseModel

@@ -8,7 +8,7 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo


class LLMMode(Enum):
class LLMMode(StrEnum):
"""
Enum class for large language model mode.
"""
@@ -30,6 +30,11 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel

logger = logging.getLogger(__name__)

HTML_THINKING_TAG = (
'<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> '
"<summary> Thinking... </summary>"
)


class LargeLanguageModel(AIModel):
"""

@@ -400,6 +405,40 @@ if you are not sure about the structure.
),
)

def _wrap_thinking_by_reasoning_content(self, delta: dict, is_reasoning: bool) -> tuple[str, bool]:
"""
If the reasoning response is from delta.get("reasoning_content"), we wrap
it with HTML details tag.

:param delta: delta dictionary from LLM streaming response
:param is_reasoning: is reasoning
:return: tuple of (processed_content, is_reasoning)
"""

content = delta.get("content") or ""
reasoning_content = delta.get("reasoning_content")

if reasoning_content:
if not is_reasoning:
content = HTML_THINKING_TAG + reasoning_content
is_reasoning = True
else:
content = reasoning_content
elif is_reasoning:
content = "</details>" + content
is_reasoning = False
return content, is_reasoning

def _wrap_thinking_by_tag(self, content: str) -> str:
"""
if the reasoning response is a <think>...</think> block from delta.get("content"),
we replace <think> to <detail>.

:param content: delta.get("content")
:return: processed_content
"""
return content.replace("<think>", HTML_THINKING_TAG).replace("</think>", "</details>")

def _invoke_result_generator(
self,
model: str,
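A standalone sketch of how the reasoning-content wrapping above behaves over a stream; the deltas are made-up examples of what an OpenAI-compatible streaming API might emit, not output captured from a real model.

```python
# Mirrors _wrap_thinking_by_reasoning_content from the hunk above.
HTML_THINKING_TAG = (
    '<details style="color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;" open> '
    "<summary> Thinking... </summary>"
)

def wrap_thinking(delta: dict, is_reasoning: bool) -> tuple[str, bool]:
    content = delta.get("content") or ""
    reasoning_content = delta.get("reasoning_content")
    if reasoning_content:
        if not is_reasoning:
            content = HTML_THINKING_TAG + reasoning_content  # open the details block
            is_reasoning = True
        else:
            content = reasoning_content
    elif is_reasoning:
        content = "</details>" + content  # close the block on the first normal chunk
        is_reasoning = False
    return content, is_reasoning

deltas = [{"reasoning_content": "step 1"}, {"reasoning_content": " step 2"}, {"content": "final answer"}]
state, out = False, []
for d in deltas:
    chunk, state = wrap_thinking(d, state)
    out.append(chunk)
print("".join(out))
# <details ...><summary> Thinking... </summary>step 1 step 2</details>final answer
```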
@@ -1,4 +1,5 @@
- openai
- deepseek
- anthropic
- azure_openai
- google

@@ -32,7 +33,6 @@
- localai
- volcengine_maas
- openai_api_compatible
- deepseek
- hunyuan
- siliconflow
- perfxcloud
@@ -51,6 +51,40 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
show_on:
- variable: __model_type
value: llm
type: text-input
default: "4096"
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: jwt_token
required: true
label:
@@ -1,9 +1,9 @@
import logging
from collections.abc import Generator
from collections.abc import Generator, Sequence
from typing import Any, Optional, Union

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import StreamingChatCompletionsUpdate
from azure.ai.inference.models import StreamingChatCompletionsUpdate, SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import (
ClientAuthenticationError,

@@ -20,7 +20,7 @@ from azure.core.exceptions import (
)

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,

@@ -30,6 +30,7 @@ from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
I18nObject,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,

@@ -60,10 +61,10 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
tools: Optional[Sequence[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:

@@ -82,8 +83,8 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
"""

if not self.client:
endpoint = credentials.get("endpoint")
api_key = credentials.get("api_key")
endpoint = str(credentials.get("endpoint"))
api_key = str(credentials.get("api_key"))
self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))

messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages]

@@ -94,6 +95,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
"temperature": model_parameters.get("temperature", 0),
"top_p": model_parameters.get("top_p", 1),
"stream": stream,
"model": model,
}

if stop:

@@ -255,10 +257,16 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
:return:
"""
try:
endpoint = credentials.get("endpoint")
api_key = credentials.get("api_key")
endpoint = str(credentials.get("endpoint"))
api_key = str(credentials.get("api_key"))
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
client.get_model_info()
client.complete(
messages=[
SystemMessage(content="I say 'ping', you say 'pong'"),
UserMessage(content="ping"),
],
model=model,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

@@ -327,7 +335,10 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
features=[],
model_properties={},
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")),
ModelPropertyKey.MODE: credentials.get("mode", LLMMode.CHAT),
},
parameter_rules=rules,
)
@@ -138,6 +138,18 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: o3-mini
value: o3-mini
show_on:
- variable: __model_type
value: llm
- label:
en_US: o3-mini-2025-01-31
value: o3-mini-2025-01-31
show_on:
- variable: __model_type
value: llm
- label:
en_US: o1-preview
value: o1-preview
@@ -123,6 +123,15 @@ provider_credential_schema:
en_US: AWS GovCloud (US-West)
zh_Hans: AWS GovCloud (US-West)
ja_JP: AWS GovCloud (米国西部)
- variable: bedrock_endpoint_url
label:
zh_Hans: Bedrock Endpoint URL
en_US: Bedrock Endpoint URL
type: text-input
required: false
placeholder:
zh_Hans: 在此输入您的 Bedrock Endpoint URL, 如:https://123456.cloudfront.net
en_US: Enter your Bedrock Endpoint URL, e.g. https://123456.cloudfront.net
- variable: model_for_validation
required: false
label:
@@ -13,6 +13,7 @@ def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
client_config = Config(region_name=region_name)
aws_access_key_id = credentials.get("aws_access_key_id")
aws_secret_access_key = credentials.get("aws_secret_access_key")
bedrock_endpoint_url = credentials.get("bedrock_endpoint_url")

if aws_access_key_id and aws_secret_access_key:
# use aksk to call bedrock

@@ -21,6 +22,7 @@ def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
config=client_config,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
**({"endpoint_url": bedrock_endpoint_url} if bedrock_endpoint_url else {}),
)
else:
# use iam without aksk to call
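A hedged sketch of the optional endpoint override that get_bedrock_client gains above, written against boto3 directly; the credential values and endpoint URL are placeholders, not real secrets or infrastructure.

```python
import boto3
from botocore.config import Config

# Placeholder credentials mirroring the shape used by get_bedrock_client above.
credentials = {
    "aws_access_key_id": "AKIA...",                  # placeholder
    "aws_secret_access_key": "...",                  # placeholder
    "aws_region": "us-east-1",
    "bedrock_endpoint_url": "https://123456.cloudfront.net",  # optional override
}

client_config = Config(region_name=credentials["aws_region"])
endpoint_url = credentials.get("bedrock_endpoint_url")

client = boto3.client(
    "bedrock-runtime",
    config=client_config,
    aws_access_key_id=credentials["aws_access_key_id"],
    aws_secret_access_key=credentials["aws_secret_access_key"],
    # only pass endpoint_url when the optional override is configured
    **({"endpoint_url": endpoint_url} if endpoint_url else {}),
)
```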
@@ -0,0 +1,115 @@
model: us.anthropic.claude-3-7-sonnet-20250219-v1:0
label:
en_US: Claude 3.7 Sonnet(US.Cross Region Inference)
icon: icon_s_en.svg
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: enable_cache
label:
zh_Hans: 启用提示缓存
en_US: Enable Prompt Cache
type: boolean
required: false
default: true
help:
zh_Hans: 启用提示缓存可以提高性能并降低成本。Claude 3.7 Sonnet支持在system、messages和tools字段中使用缓存检查点。
en_US: Enable prompt caching to improve performance and reduce costs. Claude 3.7 Sonnet supports cache checkpoints in system, messages, and tools fields.
- name: reasoning_type
label:
zh_Hans: 推理配置
en_US: Reasoning Type
type: boolean
required: false
default: false
placeholder:
zh_Hans: 设置推理配置
en_US: Set reasoning configuration
help:
zh_Hans: 控制模型的推理能力。启用时,temperature将固定为1且top_p将被禁用。
en_US: Controls the model's reasoning capability. When enabled, temperature will be fixed to 1 and top_p will be disabled.
- name: reasoning_budget
show_on:
- variable: reasoning_type
value: true
label:
zh_Hans: 推理预算
en_US: Reasoning Budget
type: int
default: 1024
min: 0
max: 128000
help:
zh_Hans: 推理的预算限制(最小1024),必须小于max_tokens。仅在推理类型为enabled时可用。
en_US: Budget limit for reasoning (minimum 1024), must be less than max_tokens. Only available when reasoning type is enabled.

- name: max_tokens
use_template: max_tokens
required: true
label:
zh_Hans: 最大token数
en_US: Max Tokens
type: int
default: 8192
min: 1
max: 128000
help:
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
- name: temperature
use_template: temperature
required: false
label:
zh_Hans: 模型温度
en_US: Model Temperature
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。当推理功能启用时,该值将被固定为1。
en_US: The amount of randomness injected into the response. When reasoning is enabled, this value will be fixed to 1.
- name: top_p
show_on:
- variable: reasoning_type
value: disabled
use_template: top_p
label:
zh_Hans: Top P
en_US: Top P
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中的概率阈值。当推理功能启用时,该参数将被禁用。
en_US: The probability threshold in nucleus sampling. When reasoning is enabled, this parameter will be disabled.
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.003'
output: '0.015'
unit: '0.001'
currency: USD
@@ -58,6 +58,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
# TODO There is invoke issue: context limit on Cohere Model, will add them after fixed.
CONVERSE_API_ENABLED_MODEL_INFO = [
{"prefix": "anthropic.claude-v2", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "us.deepseek", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "anthropic.claude-v1", "support_system_prompts": True, "support_tool_use": False},
{"prefix": "us.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
{"prefix": "eu.anthropic.claude-3", "support_system_prompts": True, "support_tool_use": True},
@@ -0,0 +1,63 @@
model: us.deepseek.r1-v1:0
label:
en_US: DeepSeek-R1(US.Cross Region Inference)
icon: icon_s_en.svg
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
label:
zh_Hans: 最大token数
en_US: Max Tokens
type: int
default: 8192
min: 1
max: 128000
help:
zh_Hans: 停止前生成的最大令牌数。
en_US: The maximum number of tokens to generate before stopping.
- name: temperature
use_template: temperature
required: false
label:
zh_Hans: 模型温度
en_US: Model Temperature
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。当推理功能启用时,该值将被固定为1。
en_US: The amount of randomness injected into the response. When reasoning is enabled, this value will be fixed to 1.
- name: top_p
show_on:
- variable: reasoning_type
value: disabled
use_template: top_p
label:
zh_Hans: Top P
en_US: Top P
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中的概率阈值。当推理功能启用时,该参数将被禁用。
en_US: The probability threshold in nucleus sampling. When reasoning is enabled, this parameter will be disabled.
- name: response_format
use_template: response_format
pricing:
input: '0.001'
output: '0.005'
unit: '0.001'
currency: USD
@@ -1,13 +1,10 @@
import json
from collections.abc import Generator
from typing import Optional, Union

import requests
from yarl import URL

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
)

@@ -39,208 +36,3 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
credentials["mode"] = LLMMode.CHAT.value
credentials["function_calling_type"] = "tool_call"
credentials["stream_function_calling"] = "support"

def _handle_generate_stream_response(
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response

:param model: model name
:param credentials: model credentials
:param response: streamed response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ""
chunk_index = 0
is_reasoning_started = False  # Add flag to track reasoning state

def create_final_llm_result_chunk(
id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = usage and usage.get("prompt_tokens")
if prompt_tokens is None:
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = usage and usage.get("completion_tokens")
if completion_tokens is None:
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)

# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

return LLMResultChunk(
id=id,
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
)

# delimiter for stream response, need unicode_escape
import codecs

delimiter = credentials.get("stream_mode_delimiter", "\n\n")
delimiter = codecs.decode(delimiter, "unicode_escape")

tools_calls: list[AssistantPromptMessage.ToolCall] = []

def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_call_id: str):
if not tool_call_id:
return tools_calls[-1]

tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
tools_calls.append(tool_call)

return tool_call

for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments

finish_reason = None  # The default value of finish_reason is None
message_id, usage = None, None
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
chunk = chunk.strip()
if chunk:
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
if decoded_chunk == "[DONE]":  # Some provider returns "data: [DONE]"
continue

try:
chunk_json: dict = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
id=message_id,
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered.",
usage=usage,
)
break
# handle the error here. for issue #11629
if chunk_json.get("error") and chunk_json.get("choices") is None:
raise ValueError(chunk_json.get("error"))

if chunk_json:
if u := chunk_json.get("usage"):
usage = u
if not chunk_json or len(chunk_json["choices"]) == 0:
continue

choice = chunk_json["choices"][0]
finish_reason = chunk_json["choices"][0].get("finish_reason")
message_id = chunk_json.get("id")
chunk_index += 1

if "delta" in choice:
delta = choice["delta"]
is_reasoning = delta.get("reasoning_content")
delta_content = delta.get("content") or delta.get("reasoning_content")

assistant_message_tool_calls = None

if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
assistant_message_tool_calls = delta.get("tool_calls", None)
elif (
"function_call" in delta
and credentials.get("function_calling_type", "no_call") == "function_call"
):
assistant_message_tool_calls = [
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
]

# assistant_message_function_call = delta.delta.function_call

# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)

if delta_content is None or delta_content == "":
continue

# Add markdown quote markers for reasoning content
if is_reasoning:
if not is_reasoning_started:
delta_content = "> 💭 " + delta_content
is_reasoning_started = True
elif "\n\n" in delta_content:
delta_content = delta_content.replace("\n\n", "\n> ")
elif "\n" in delta_content:
delta_content = delta_content.replace("\n", "\n> ")
elif is_reasoning_started:
# If we were in reasoning mode but now getting regular content,
# add \n\n to close the reasoning block
delta_content = "\n\n" + delta_content
is_reasoning_started = False

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content,
)

# reset tool calls
tool_calls = []
full_assistant_content += delta_content
elif "text" in choice:
choice_text = choice.get("text", "")
if choice_text == "":
continue

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
full_assistant_content += choice_text
else:
continue

yield LLMResultChunk(
id=message_id,
model=model,
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if tools_calls:
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
|
||||
),
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
id=message_id,
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
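The handler removed above wraps streamed `reasoning_content` deltas in markdown blockquotes so the model's "thinking" renders separately from the final answer. Below is a minimal, self-contained sketch of that quoting logic; the standalone function name is illustrative only — in the diff the logic lives inline in `_handle_generate_stream_response`.

```python
def quote_reasoning_delta(delta_content: str, is_reasoning: bool, started: bool) -> tuple[str, bool]:
    """Return the text to emit for one streamed delta and the updated reasoning flag."""
    if is_reasoning:
        if not started:
            # the first reasoning chunk opens a markdown quote block
            return "> 💭 " + delta_content, True
        if "\n\n" in delta_content:
            return delta_content.replace("\n\n", "\n> "), True
        if "\n" in delta_content:
            return delta_content.replace("\n", "\n> "), True
        return delta_content, True
    if started:
        # regular content after reasoning: a blank line closes the quote block
        return "\n\n" + delta_content, False
    return delta_content, False

# Example, feeding three deltas in order:
# quote_reasoning_delta("Let me think", True, False)  -> ("> 💭 Let me think", True)
# quote_reasoning_delta("\nstep 2", True, True)       -> ("\n> step 2", True)
# quote_reasoning_delta("Answer: 42", False, True)    -> ("\n\nAnswer: 42", False)
```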
@ -19,8 +19,8 @@ class GoogleProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)

# Use `gemini-pro` model for validate,
model_instance.validate_credentials(model="gemini-pro", credentials=credentials)
# Use `gemini-2.0-flash` model for validate,
model_instance.validate_credentials(model="gemini-2.0-flash", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:

@ -1,4 +1,6 @@
|
||||
- gemini-2.0-flash-001
|
||||
- gemini-2.0-flash-exp
|
||||
- gemini-2.0-pro-exp-02-05
|
||||
- gemini-2.0-flash-thinking-exp-1219
|
||||
- gemini-2.0-flash-thinking-exp-01-21
|
||||
- gemini-1.5-pro
|
||||
@ -17,5 +19,3 @@
|
||||
- gemini-exp-1206
|
||||
- gemini-exp-1121
|
||||
- gemini-exp-1114
|
||||
- gemini-pro
|
||||
- gemini-pro-vision
|
||||
|
||||
@ -0,0 +1,41 @@
|
||||
model: gemini-2.0-flash-001
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash 001
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,41 @@
|
||||
model: gemini-2.0-pro-exp-02-05
|
||||
label:
|
||||
en_US: Gemini 2.0 pro exp 02-05
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -1,3 +1,4 @@
|
||||
- deepseek-r1-distill-llama-70b
|
||||
- llama-3.1-405b-reasoning
|
||||
- llama-3.3-70b-versatile
|
||||
- llama-3.1-70b-versatile
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
model: deepseek-r1-distill-llama-70b
|
||||
label:
|
||||
en_US: DeepSeek R1 Distill Llama 70b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '3.00'
|
||||
output: '3.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -1,3 +1,4 @@
|
||||
- deepseek-ai/deepseek-r1
|
||||
- google/gemma-7b
|
||||
- google/codegemma-7b
|
||||
- google/recurrentgemma-2b
|
||||
|
||||
@ -0,0 +1,35 @@
|
||||
model: deepseek-ai/deepseek-r1
|
||||
label:
|
||||
en_US: deepseek-ai/deepseek-r1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0.5
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
min: 0
|
||||
max: 1
|
||||
default: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
min: 1
|
||||
max: 1024
|
||||
default: 1024
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
min: -2
|
||||
max: 2
|
||||
default: 0
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
min: -2
|
||||
max: 2
|
||||
default: 0
|
||||
@ -83,7 +83,7 @@ class NVIDIALargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _add_custom_parameters(self, credentials: dict, model: str) -> None:
credentials["mode"] = "chat"

if self.MODEL_SUFFIX_MAP[model]:
if self.MODEL_SUFFIX_MAP.get(model):
credentials["server_url"] = f"https://ai.api.nvidia.com/v1/{self.MODEL_SUFFIX_MAP[model]}"
credentials.pop("endpoint_url")
else:

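The NVIDIA fix above swaps a direct `MODEL_SUFFIX_MAP[model]` lookup for `.get(model)`, so models without a suffix entry fall through to the `else` branch instead of raising `KeyError`. A small illustration of the difference, with a hypothetical suffix map (the real map is defined on the provider class and its values may differ):

```python
# Hypothetical stand-in for the provider's MODEL_SUFFIX_MAP; entries are illustrative only.
MODEL_SUFFIX_MAP = {"meta/llama3-70b-instruct": "chat/completions"}

def resolve_server_url(model: str, endpoint_url: str) -> str:
    # .get() returns None (falsy) for unknown models instead of raising KeyError,
    # which is the behaviour the one-line fix above relies on.
    if MODEL_SUFFIX_MAP.get(model):
        return f"https://ai.api.nvidia.com/v1/{MODEL_SUFFIX_MAP[model]}"
    return endpoint_url

print(resolve_server_url("meta/llama3-70b-instruct", "https://example.invalid"))
print(resolve_server_url("some-unmapped-model", "https://example.invalid"))
```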
@ -0,0 +1,52 @@
|
||||
model: cohere.command-r-08-2024
|
||||
label:
|
||||
en_US: cohere.command-r-08-2024 v1.7
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1
|
||||
max: 1.0
|
||||
- name: topP
|
||||
use_template: top_p
|
||||
default: 0.75
|
||||
min: 0
|
||||
max: 1
|
||||
- name: topK
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 0
|
||||
min: 0
|
||||
max: 500
|
||||
- name: presencePenalty
|
||||
use_template: presence_penalty
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0
|
||||
- name: frequencyPenalty
|
||||
use_template: frequency_penalty
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0
|
||||
- name: maxTokens
|
||||
use_template: max_tokens
|
||||
default: 600
|
||||
max: 4000
|
||||
pricing:
|
||||
input: '0.0009'
|
||||
output: '0.0009'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
@ -50,3 +50,4 @@ pricing:
|
||||
output: '0.004'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
deprecated: true
|
||||
|
||||
@ -0,0 +1,52 @@
|
||||
model: cohere.command-r-plus-08-2024
|
||||
label:
|
||||
en_US: cohere.command-r-plus-08-2024 v1.6
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1
|
||||
max: 1.0
|
||||
- name: topP
|
||||
use_template: top_p
|
||||
default: 0.75
|
||||
min: 0
|
||||
max: 1
|
||||
- name: topK
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 0
|
||||
min: 0
|
||||
max: 500
|
||||
- name: presencePenalty
|
||||
use_template: presence_penalty
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0
|
||||
- name: frequencyPenalty
|
||||
use_template: frequency_penalty
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0
|
||||
- name: maxTokens
|
||||
use_template: max_tokens
|
||||
default: 600
|
||||
max: 4000
|
||||
pricing:
|
||||
input: '0.0156'
|
||||
output: '0.0156'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
@ -50,3 +50,4 @@ pricing:
|
||||
output: '0.0219'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
deprecated: true
|
||||
|
||||
@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
request_template = {
|
||||
"compartmentId": "",
|
||||
"servingMode": {"modelId": "cohere.command-r-plus", "servingType": "ON_DEMAND"},
|
||||
"servingMode": {"modelId": "cohere.command-r-plus-08-2024", "servingType": "ON_DEMAND"},
|
||||
"chatRequest": {
|
||||
"apiFormat": "COHERE",
|
||||
# "preambleOverride": "You are a helpful assistant.",
|
||||
@ -60,19 +60,19 @@ oci_config_template = {
|
||||
class OCILargeLanguageModel(LargeLanguageModel):
|
||||
# https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
|
||||
_supported_models = {
|
||||
"meta.llama-3-70b-instruct": {
|
||||
"meta.llama-3.1-70b-instruct": {
|
||||
"system": True,
|
||||
"multimodal": False,
|
||||
"tool_call": False,
|
||||
"stream_tool_call": False,
|
||||
},
|
||||
"cohere.command-r-16k": {
|
||||
"cohere.command-r-08-2024": {
|
||||
"system": True,
|
||||
"multimodal": False,
|
||||
"tool_call": True,
|
||||
"stream_tool_call": False,
|
||||
},
|
||||
"cohere.command-r-plus": {
|
||||
"cohere.command-r-plus-08-2024": {
|
||||
"system": True,
|
||||
"multimodal": False,
|
||||
"tool_call": True,
|
||||
|
||||
@ -49,3 +49,4 @@ pricing:
|
||||
output: '0.015'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
deprecated: true
|
||||
|
||||
@ -0,0 +1,51 @@
|
||||
model: meta.llama-3.1-70b-instruct
|
||||
label:
|
||||
zh_Hans: meta.llama-3.1-70b-instruct
|
||||
en_US: meta.llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1
|
||||
max: 2.0
|
||||
- name: topP
|
||||
use_template: top_p
|
||||
default: 0.75
|
||||
min: 0
|
||||
max: 1
|
||||
- name: topK
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 0
|
||||
min: 0
|
||||
max: 500
|
||||
- name: presencePenalty
|
||||
use_template: presence_penalty
|
||||
min: -2
|
||||
max: 2
|
||||
default: 0
|
||||
- name: frequencyPenalty
|
||||
use_template: frequency_penalty
|
||||
min: -2
|
||||
max: 2
|
||||
default: 0
|
||||
- name: maxTokens
|
||||
use_template: max_tokens
|
||||
default: 600
|
||||
max: 4000
|
||||
pricing:
|
||||
input: '0.0075'
|
||||
output: '0.0075'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
||||
@ -19,8 +19,8 @@ class OCIGENAIProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)

# Use `cohere.command-r-plus` model for validate,
model_instance.validate_credentials(model="cohere.command-r-plus", credentials=credentials)
# Use `cohere.command-r-plus-08-2024` model for validate,
model_instance.validate_credentials(model="cohere.command-r-plus-08-2024", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:

@ -367,6 +367,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):

# transform assistant message to prompt message
text = chunk_json["response"]
text = self._wrap_thinking_by_tag(text)

assistant_prompt_message = AssistantPromptMessage(content=text)

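The Ollama hunk above routes each streamed `response` text through a shared `_wrap_thinking_by_tag` helper before building the assistant message. The diff does not show that helper's body; the sketch below only illustrates the `<think>` → markdown-quote convention that the removed SiliconFlow handler later in this diff applies inline, so treat the function as an assumption rather than the repository's implementation. It also assumes complete text — a streaming version would need to carry state across partial tags.

```python
import re

def wrap_think_tags_as_quote(text: str) -> str:
    # Assumption: turn a complete <think>...</think> block into a markdown quote,
    # mirroring the inline handling in the removed SiliconFlow streaming handler.
    def _to_quote(match: re.Match) -> str:
        body = match.group(1).strip()
        quoted = "\n> ".join(body.splitlines())
        return f"> 💭 {quoted}\n\n"

    return re.sub(r"<think>(.*?)</think>", _to_quote, text, flags=re.DOTALL)

print(wrap_think_tags_as_quote("<think>step 1\nstep 2</think>The answer is 4."))
```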
@ -1,7 +1,10 @@
|
||||
- gpt-4.1
|
||||
- o1
|
||||
- o1-2024-12-17
|
||||
- o1-mini
|
||||
- o1-mini-2024-09-12
|
||||
- o3-mini
|
||||
- o3-mini-2025-01-31
|
||||
- gpt-4
|
||||
- gpt-4o
|
||||
- gpt-4o-2024-05-13
|
||||
|
||||
@ -0,0 +1,60 @@
|
||||
model: gpt-4.1
|
||||
label:
|
||||
zh_Hans: gpt-4.1
|
||||
en_US: gpt-4.1
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1047576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 32768
|
||||
- name: reasoning_effort
|
||||
label:
|
||||
zh_Hans: 推理工作
|
||||
en_US: Reasoning Effort
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 限制推理模型的推理工作
|
||||
en_US: Constrains effort on reasoning for reasoning models
|
||||
required: false
|
||||
options:
|
||||
- low
|
||||
- medium
|
||||
- high
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
- json_schema
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '2.00'
|
||||
output: '8.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -619,9 +619,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
# clear illegal prompt messages
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)

# o1 compatibility
# o1, o3 compatibility
block_as_stream = False
if model.startswith("o1"):
if model.startswith(("o1", "o3")):
if "max_tokens" in model_parameters:
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
del model_parameters["max_tokens"]
@ -941,7 +941,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
]
)

if model.startswith("o1"):
if model.startswith(("o1", "o3")):
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
if system_message_count > 0:
new_prompt_messages = []
@ -1049,26 +1049,29 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if not messages and not tools:
return 0

if model.startswith("ft:"):
model = model.split(":")[1]

# Currently, we can use gpt4o to calculate chatgpt-4o-latest's token.
if model == "chatgpt-4o-latest" or model.startswith("o1"):
if model == "chatgpt-4o-latest" or model.startswith(("o1", "o3")):
model = "gpt-4o"

try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
except (KeyError, ValueError) as e:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
encoding_name = "cl100k_base"
encoding = tiktoken.get_encoding(encoding_name)

if model.startswith("gpt-3.5-turbo-0301"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4") or model.startswith(("o1", "o3", "o4")):
tokens_per_message = 3
tokens_per_name = 1
else:

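The token-counting change above broadens the exception handling around `tiktoken.encoding_for_model` and keeps the model name separate from the encoding name, so the later `model.startswith(...)` branches still see the original model string. A minimal sketch of that fallback pattern, assuming `tiktoken` is installed:

```python
import tiktoken

def get_encoding_for(model: str) -> tiktoken.Encoding:
    """Resolve a tiktoken encoding, falling back to cl100k_base for unknown models."""
    try:
        return tiktoken.encoding_for_model(model)
    except (KeyError, ValueError):
        # Unknown or non-OpenAI model names raise here; fall back without
        # overwriting `model`, so callers can still branch on the model name.
        return tiktoken.get_encoding("cl100k_base")

encoding = get_encoding_for("o3-mini")
print(len(encoding.encode("hello world")))
```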
@ -16,6 +16,19 @@ parameter_rules:
|
||||
default: 50000
|
||||
min: 1
|
||||
max: 50000
|
||||
- name: reasoning_effort
|
||||
label:
|
||||
zh_Hans: 推理工作
|
||||
en_US: reasoning_effort
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 限制推理模型的推理工作
|
||||
en_US: constrains effort on reasoning for reasoning models
|
||||
required: false
|
||||
options:
|
||||
- low
|
||||
- medium
|
||||
- high
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
|
||||
@ -17,6 +17,19 @@ parameter_rules:
|
||||
default: 50000
|
||||
min: 1
|
||||
max: 50000
|
||||
- name: reasoning_effort
|
||||
label:
|
||||
zh_Hans: 推理工作
|
||||
en_US: reasoning_effort
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 限制推理模型的推理工作
|
||||
en_US: constrains effort on reasoning for reasoning models
|
||||
required: false
|
||||
options:
|
||||
- low
|
||||
- medium
|
||||
- high
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
|
||||
@ -0,0 +1,46 @@
|
||||
model: o3-mini-2025-01-31
|
||||
label:
|
||||
zh_Hans: o3-mini-2025-01-31
|
||||
en_US: o3-mini-2025-01-31
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 100000
|
||||
min: 1
|
||||
max: 100000
|
||||
- name: reasoning_effort
|
||||
label:
|
||||
zh_Hans: 推理工作
|
||||
en_US: reasoning_effort
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 限制推理模型的推理工作
|
||||
en_US: constrains effort on reasoning for reasoning models
|
||||
required: false
|
||||
options:
|
||||
- low
|
||||
- medium
|
||||
- high
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '1.10'
|
||||
output: '4.40'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,46 @@
|
||||
model: o3-mini
|
||||
label:
|
||||
zh_Hans: o3-mini
|
||||
en_US: o3-mini
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 100000
|
||||
min: 1
|
||||
max: 100000
|
||||
- name: reasoning_effort
|
||||
label:
|
||||
zh_Hans: 推理工作
|
||||
en_US: reasoning_effort
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 限制推理模型的推理工作
|
||||
en_US: constrains effort on reasoning for reasoning models
|
||||
required: false
|
||||
options:
|
||||
- low
|
||||
- medium
|
||||
- high
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: '1.10'
|
||||
output: '4.40'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -1,5 +1,5 @@
|
||||
import codecs
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Union, cast
|
||||
@ -38,8 +38,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
"""
|
||||
@ -99,7 +97,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
|
||||
return self._num_tokens_from_messages(prompt_messages, tools, credentials)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
@ -398,6 +396,73 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _create_final_llm_result_chunk(
|
||||
self,
|
||||
index: int,
|
||||
message: AssistantPromptMessage,
|
||||
finish_reason: str,
|
||||
usage: dict,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
full_content: str,
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = usage and usage.get("prompt_tokens")
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = self._num_tokens_from_string(text=prompt_messages[0].content)
|
||||
completion_tokens = usage and usage.get("completion_tokens")
|
||||
if completion_tokens is None:
|
||||
completion_tokens = self._num_tokens_from_string(text=full_content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
return LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
||||
)
|
||||
|
||||
def _get_tool_call(self, tool_call_id: str, tools_calls: list[AssistantPromptMessage.ToolCall]):
|
||||
"""
|
||||
Get or create a tool call by ID
|
||||
|
||||
:param tool_call_id: tool call ID
|
||||
:param tools_calls: list of existing tool calls
|
||||
:return: existing or new tool call, updated tools_calls
|
||||
"""
|
||||
if not tool_call_id:
|
||||
return tools_calls[-1], tools_calls
|
||||
|
||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
|
||||
return tool_call, tools_calls
|
||||
|
||||
def _increase_tool_call(
|
||||
self, new_tool_calls: list[AssistantPromptMessage.ToolCall], tools_calls: list[AssistantPromptMessage.ToolCall]
|
||||
) -> list[AssistantPromptMessage.ToolCall]:
|
||||
for new_tool_call in new_tool_calls:
|
||||
# get tool call
|
||||
tool_call, tools_calls = self._get_tool_call(new_tool_call.function.name, tools_calls)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
return tools_calls
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
@ -410,69 +475,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
full_assistant_content = ""
|
||||
chunk_index = 0
|
||||
|
||||
def create_final_llm_result_chunk(
|
||||
id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = usage and usage.get("prompt_tokens")
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = usage and usage.get("completion_tokens")
|
||||
if completion_tokens is None:
|
||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
return LLMResultChunk(
|
||||
id=id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
||||
)
|
||||
|
||||
full_assistant_content = ""
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
finish_reason = None
|
||||
usage = None
|
||||
is_reasoning_started = False
|
||||
# delimiter for stream response, need unicode_escape
|
||||
import codecs
|
||||
|
||||
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
|
||||
delimiter = codecs.decode(delimiter, "unicode_escape")
|
||||
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||
def get_tool_call(tool_call_id: str):
|
||||
if not tool_call_id:
|
||||
return tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.function.name)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
finish_reason = None # The default value of finish_reason is None
|
||||
message_id, usage = None, None
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
chunk = chunk.strip()
|
||||
if chunk:
|
||||
@ -487,12 +498,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
chunk_json: dict = json.loads(decoded_chunk)
|
||||
# stream ended
|
||||
except json.JSONDecodeError as e:
|
||||
yield create_final_llm_result_chunk(
|
||||
id=message_id,
|
||||
yield self._create_final_llm_result_chunk(
|
||||
index=chunk_index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered.",
|
||||
usage=usage,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
full_content=full_assistant_content,
|
||||
)
|
||||
break
|
||||
# handle the error here. for issue #11629
|
||||
@ -507,12 +521,14 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
|
||||
choice = chunk_json["choices"][0]
|
||||
finish_reason = chunk_json["choices"][0].get("finish_reason")
|
||||
message_id = chunk_json.get("id")
|
||||
chunk_index += 1
|
||||
|
||||
if "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
delta_content = delta.get("content")
|
||||
delta_content, is_reasoning_started = self._wrap_thinking_by_reasoning_content(
|
||||
delta, is_reasoning_started
|
||||
)
|
||||
delta_content = self._wrap_thinking_by_tag(delta_content)
|
||||
|
||||
assistant_message_tool_calls = None
|
||||
|
||||
@ -526,12 +542,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
|
||||
]
|
||||
|
||||
# assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
if assistant_message_tool_calls:
|
||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
increase_tool_call(tool_calls)
|
||||
tools_calls = self._increase_tool_call(tool_calls, tools_calls)
|
||||
|
||||
if delta_content is None or delta_content == "":
|
||||
continue
|
||||
@ -556,7 +570,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
continue
|
||||
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
@ -569,7 +582,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
|
||||
if tools_calls:
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
@ -578,12 +590,15 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
),
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
id=message_id,
|
||||
yield self._create_final_llm_result_chunk(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
full_content=full_assistant_content,
|
||||
)
|
||||
|
||||
def _handle_generate_response(
|
||||
@ -697,12 +712,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(
|
||||
self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
|
||||
self, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Approximate num tokens for model with gpt2 tokenizer.
|
||||
|
||||
:param model: model name
|
||||
:param text: prompt text
|
||||
:param tools: tools for tool calling
|
||||
:return: number of tokens
|
||||
@ -725,7 +739,6 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
credentials: Optional[dict] = None,
|
||||
|
||||
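The refactor above centralizes streamed tool-call assembly in `_get_tool_call` / `_increase_tool_call`: each SSE delta may carry only a fragment of a tool call, so fragments are matched to an existing entry (or a new placeholder) and their `arguments` strings are concatenated. A stripped-down sketch of that accumulation using a plain dataclass instead of `AssistantPromptMessage.ToolCall` (the real type lives in the model runtime entities):

```python
from dataclasses import dataclass

@dataclass
class ToolCallDraft:
    # Simplified stand-in for AssistantPromptMessage.ToolCall
    id: str = ""
    name: str = ""
    arguments: str = ""

def merge_tool_call_deltas(deltas: list[dict]) -> list[ToolCallDraft]:
    calls: list[ToolCallDraft] = []
    for delta in deltas:
        key = delta.get("name") or ""
        # reuse the last draft when the delta carries no identifier,
        # otherwise find (or create) the draft with the same name
        draft = next((c for c in calls if key and c.name == key), None)
        if draft is None:
            if not key and calls:
                draft = calls[-1]
            else:
                draft = ToolCallDraft(name=key)
                calls.append(draft)
        if delta.get("id"):
            draft.id = delta["id"]
        # argument fragments arrive piecewise and are concatenated
        draft.arguments += delta.get("arguments", "")
    return calls

chunks = [
    {"id": "call_1", "name": "get_weather", "arguments": '{"city": "Par'},
    {"arguments": 'is"}'},
]
print(merge_tool_call_deltas(chunks))
```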
@ -1,5 +1,7 @@
|
||||
- openai/o1-preview
|
||||
- openai/o1-mini
|
||||
- openai/o3-mini
|
||||
- openai/o3-mini-2025-01-31
|
||||
- openai/gpt-4o
|
||||
- openai/gpt-4o-mini
|
||||
- openai/gpt-4
|
||||
@ -28,5 +30,6 @@
|
||||
- mistralai/mistral-7b-instruct
|
||||
- qwen/qwen-2.5-72b-instruct
|
||||
- qwen/qwen-2-72b-instruct
|
||||
- deepseek/deepseek-r1
|
||||
- deepseek/deepseek-chat
|
||||
- deepseek/deepseek-coder
|
||||
|
||||
@ -53,7 +53,7 @@ parameter_rules:
|
||||
zh_Hans: 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。
|
||||
en_US: A number between -2.0 and 2.0. If the value is positive, new tokens are penalized based on their frequency of occurrence in existing text, reducing the likelihood that the model will repeat the same content.
|
||||
pricing:
|
||||
input: "0.14"
|
||||
output: "0.28"
|
||||
input: "0.49"
|
||||
output: "0.89"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
||||
|
||||
@ -0,0 +1,59 @@
|
||||
model: deepseek/deepseek-r1
|
||||
label:
|
||||
en_US: deepseek-r1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 163840
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。
|
||||
en_US: Control the diversity and randomness of generated results. The smaller the value, the more rigorous it is; the larger the value, the more divergent it is.
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.01
|
||||
max: 1.00
|
||||
help:
|
||||
zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。
|
||||
en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature.
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
default: 0
|
||||
min: -2.0
|
||||
max: 2.0
|
||||
help:
|
||||
zh_Hans: 介于 -2.0 和 2.0 之间的数字。如果该值为正,那么新 token 会根据其在已有文本中的出现频率受到相应的惩罚,降低模型重复相同内容的可能性。
|
||||
en_US: A number between -2.0 and 2.0. If the value is positive, new tokens are penalized based on their frequency of occurrence in existing text, reducing the likelihood that the model will repeat the same content.
|
||||
pricing:
|
||||
input: "3"
|
||||
output: "8"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
||||
@ -0,0 +1,49 @@
|
||||
model: openai/o3-mini-2025-01-31
|
||||
label:
|
||||
en_US: o3-mini-2025-01-31
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 100000
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: "1.10"
|
||||
output: "4.40"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
||||
@ -0,0 +1,49 @@
|
||||
model: openai/o3-mini
|
||||
label:
|
||||
en_US: o3-mini
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 512
|
||||
min: 1
|
||||
max: 100000
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response_format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: "1.10"
|
||||
output: "4.40"
|
||||
unit: "0.000001"
|
||||
currency: USD
|
||||
@ -12,7 +12,11 @@
|
||||
- Pro/Qwen/Qwen2-VL-7B-Instruct
|
||||
- OpenGVLab/InternVL2-26B
|
||||
- Pro/OpenGVLab/InternVL2-8B
|
||||
- deepseek-ai/DeepSeek-R1
|
||||
- deepseek-ai/DeepSeek-V2-Chat
|
||||
- deepseek-ai/DeepSeek-V2.5
|
||||
- deepseek-ai/DeepSeek-V3
|
||||
- deepseek-ai/DeepSeek-Coder-V2-Instruct
|
||||
- THUDM/glm-4-9b-chat
|
||||
- 01-ai/Yi-1.5-34B-Chat-16K
|
||||
- 01-ai/Yi-1.5-9B-Chat-16K
|
||||
@ -25,3 +29,4 @@
|
||||
- meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
- google/gemma-2-27b-it
|
||||
- google/gemma-2-9b-it
|
||||
- Tencent/Hunyuan-A52B-Instruct
|
||||
|
||||
@ -0,0 +1,21 @@
|
||||
model: deepseek-ai/DeepSeek-R1
|
||||
label:
|
||||
zh_Hans: deepseek-ai/DeepSeek-R1
|
||||
en_US: deepseek-ai/DeepSeek-R1
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 64000
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
min: 1
|
||||
max: 8192
|
||||
default: 4096
|
||||
pricing:
|
||||
input: "4"
|
||||
output: "16"
|
||||
unit: "0.000001"
|
||||
currency: RMB
|
||||
@ -0,0 +1,53 @@
|
||||
model: deepseek-ai/DeepSeek-V3
|
||||
label:
|
||||
en_US: deepseek-ai/DeepSeek-V3
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 64000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
type: int
|
||||
default: 512
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
|
||||
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: response_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: Response Format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
pricing:
|
||||
input: "1"
|
||||
output: "2"
|
||||
unit: "0.000001"
|
||||
currency: RMB
|
||||
@ -1,13 +1,9 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
@ -96,208 +92,3 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_generate_stream_response(
|
||||
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param response: streamed response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
full_assistant_content = ""
|
||||
chunk_index = 0
|
||||
is_reasoning_started = False # Add flag to track reasoning state
|
||||
|
||||
def create_final_llm_result_chunk(
|
||||
id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = usage and usage.get("prompt_tokens")
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
|
||||
completion_tokens = usage and usage.get("completion_tokens")
|
||||
if completion_tokens is None:
|
||||
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
return LLMResultChunk(
|
||||
id=id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
|
||||
)
|
||||
|
||||
# delimiter for stream response, need unicode_escape
|
||||
import codecs
|
||||
|
||||
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
|
||||
delimiter = codecs.decode(delimiter, "unicode_escape")
|
||||
|
||||
tools_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
|
||||
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
|
||||
def get_tool_call(tool_call_id: str):
|
||||
if not tool_call_id:
|
||||
return tools_calls[-1]
|
||||
|
||||
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
|
||||
if tool_call is None:
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
|
||||
)
|
||||
tools_calls.append(tool_call)
|
||||
|
||||
return tool_call
|
||||
|
||||
for new_tool_call in new_tool_calls:
|
||||
# get tool call
|
||||
tool_call = get_tool_call(new_tool_call.function.name)
|
||||
# update tool call
|
||||
if new_tool_call.id:
|
||||
tool_call.id = new_tool_call.id
|
||||
if new_tool_call.type:
|
||||
tool_call.type = new_tool_call.type
|
||||
if new_tool_call.function.name:
|
||||
tool_call.function.name = new_tool_call.function.name
|
||||
if new_tool_call.function.arguments:
|
||||
tool_call.function.arguments += new_tool_call.function.arguments
|
||||
|
||||
finish_reason = None # The default value of finish_reason is None
|
||||
message_id, usage = None, None
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
|
||||
chunk = chunk.strip()
|
||||
if chunk:
|
||||
# ignore sse comments
|
||||
if chunk.startswith(":"):
|
||||
continue
|
||||
decoded_chunk = chunk.strip().removeprefix("data:").lstrip()
|
||||
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_json: dict = json.loads(decoded_chunk)
|
||||
# stream ended
|
||||
except json.JSONDecodeError as e:
|
||||
yield create_final_llm_result_chunk(
|
||||
id=message_id,
|
||||
index=chunk_index + 1,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered.",
|
||||
usage=usage,
|
||||
)
|
||||
break
|
||||
# handle the error here. for issue #11629
|
||||
if chunk_json.get("error") and chunk_json.get("choices") is None:
|
||||
raise ValueError(chunk_json.get("error"))
|
||||
|
||||
if chunk_json:
|
||||
if u := chunk_json.get("usage"):
|
||||
usage = u
|
||||
if not chunk_json or len(chunk_json["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk_json["choices"][0]
|
||||
finish_reason = chunk_json["choices"][0].get("finish_reason")
|
||||
message_id = chunk_json.get("id")
|
||||
chunk_index += 1
|
||||
|
||||
if "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
delta_content = delta.get("content")
|
||||
|
||||
assistant_message_tool_calls = None
|
||||
|
||||
if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
|
||||
assistant_message_tool_calls = delta.get("tool_calls", None)
|
||||
elif (
|
||||
"function_call" in delta
|
||||
and credentials.get("function_calling_type", "no_call") == "function_call"
|
||||
):
|
||||
assistant_message_tool_calls = [
|
||||
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
|
||||
]
|
||||
|
||||
# assistant_message_function_call = delta.delta.function_call
|
||||
|
||||
# extract tool calls from response
|
||||
if assistant_message_tool_calls:
|
||||
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
|
||||
increase_tool_call(tool_calls)
|
||||
|
||||
if delta_content is None or delta_content == "":
|
||||
continue
|
||||
|
||||
# Check for think tags
|
||||
if "<think>" in delta_content:
|
||||
is_reasoning_started = True
|
||||
# Remove <think> tag and add markdown quote
|
||||
delta_content = "> 💭 " + delta_content.replace("<think>", "")
|
||||
elif "</think>" in delta_content:
|
||||
# Remove </think> tag and add newlines to end quote block
|
||||
delta_content = delta_content.replace("</think>", "") + "\n\n"
|
||||
is_reasoning_started = False
|
||||
elif is_reasoning_started:
|
||||
# Add quote markers for content within thinking block
|
||||
if "\n\n" in delta_content:
|
||||
delta_content = delta_content.replace("\n\n", "\n> ")
|
||||
elif "\n" in delta_content:
|
||||
delta_content = delta_content.replace("\n", "\n> ")
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta_content,
|
||||
)
|
||||
|
||||
# reset tool calls
|
||||
tool_calls = []
|
||||
full_assistant_content += delta_content
|
||||
elif "text" in choice:
|
||||
choice_text = choice.get("text", "")
|
||||
if choice_text == "":
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
|
||||
full_assistant_content += choice_text
|
||||
else:
|
||||
continue
|
||||
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
),
|
||||
)
|
||||
|
||||
chunk_index += 1
|
||||
|
||||
if tools_calls:
|
||||
yield LLMResultChunk(
|
||||
id=message_id,
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
|
||||
),
|
||||
)
|
||||
|
||||
yield create_final_llm_result_chunk(
|
||||
id=message_id,
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
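The SiliconFlow handler removed above (its duties now fall to the shared OpenAI-compatible base class) frames the raw SSE stream the same way: split on a configurable delimiter, drop `:`-prefixed comments, strip the `data:` prefix, skip the `[DONE]` sentinel, and JSON-decode the rest. Below is a self-contained sketch of just that framing step over an in-memory list of lines; in the real handler the lines come from `requests.Response.iter_lines`, which is not reproduced here.

```python
import json
from collections.abc import Iterator

def iter_sse_payloads(lines: Iterator[str]) -> Iterator[dict]:
    """Yield decoded JSON payloads from server-sent-event lines."""
    for raw in lines:
        chunk = raw.strip()
        if not chunk or chunk.startswith(":"):
            continue  # blank keep-alives and SSE comments
        decoded = chunk.removeprefix("data:").lstrip()
        if decoded == "[DONE]":  # some providers send "data: [DONE]"
            continue
        try:
            yield json.loads(decoded)
        except json.JSONDecodeError:
            break  # treat non-JSON as end of stream, like the handler above

sample = [": ping", 'data: {"choices": [{"delta": {"content": "Hi"}}]}', "data: [DONE]"]
for payload in iter_sse_payloads(iter(sample)):
    print(payload["choices"][0]["delta"]["content"])
```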
@ -0,0 +1,37 @@
|
||||
model: gemini-2.0-flash-001
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash 001
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,41 @@
|
||||
model: gemini-2.0-flash-lite-preview-02-05
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash Lite Preview 0205
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,39 @@
|
||||
model: gemini-2.0-flash-thinking-exp-01-21
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash Thinking Exp 0121
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,39 @@
|
||||
model: gemini-2.0-flash-thinking-exp-1219
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash Thinking Exp 1219
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,37 @@
|
||||
model: gemini-2.0-pro-exp-02-05
|
||||
label:
|
||||
en_US: Gemini 2.0 Pro Exp 0205
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- document
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2000000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -1,14 +1,18 @@
|
||||
model: gemini-pro
|
||||
model: gemini-exp-1114
|
||||
label:
|
||||
en_US: Gemini Pro
|
||||
en_US: Gemini exp 1114
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 30720
|
||||
context_size: 32767
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
@ -23,17 +27,15 @@ parameter_rules:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 2048
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 2048
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
deprecated: true
|
||||
@ -1,12 +1,18 @@
|
||||
model: gemini-pro-vision
|
||||
model: gemini-exp-1121
|
||||
label:
|
||||
en_US: Gemini Pro Vision
|
||||
en_US: Gemini exp 1121
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 12288
|
||||
context_size: 32767
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
@ -21,15 +27,15 @@ parameter_rules:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens_to_sample
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 4096
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 4096
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
deprecated: true
|
||||
@ -0,0 +1,41 @@
|
||||
model: gemini-exp-1206
|
||||
label:
|
||||
en_US: Gemini exp 1206
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
@ -247,15 +248,34 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
req_params["tools"] = tools
|
||||
|
||||
def _handle_stream_chat_response(chunks: Generator[ChatCompletionChunk]) -> Generator:
|
||||
is_reasoning_started = False
|
||||
for chunk in chunks:
|
||||
content = ""
|
||||
if chunk.choices:
|
||||
delta = chunk.choices[0].delta
|
||||
if is_reasoning_started and not hasattr(delta, "reasoning_content") and not delta.content:
|
||||
content = ""
|
||||
elif hasattr(delta, "reasoning_content"):
|
||||
if not is_reasoning_started:
|
||||
is_reasoning_started = True
|
||||
content = "> 💭 " + delta.reasoning_content
|
||||
else:
|
||||
content = delta.reasoning_content
|
||||
|
||||
if "\n" in content:
|
||||
content = re.sub(r"\n(?!(>|\n))", "\n> ", content)
|
||||
elif is_reasoning_started:
|
||||
content = "\n\n" + delta.content
|
||||
is_reasoning_started = False
|
||||
else:
|
||||
content = delta.content
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=chunk.choices[0].delta.content if chunk.choices else "", tool_calls=[]
|
||||
),
|
||||
message=AssistantPromptMessage(content=content, tool_calls=[]),
|
||||
usage=self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
|
||||
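A minimal standalone sketch of the reasoning-content quoting logic added in the Volcengine stream handler above. The helper name wrap_reasoning and its signature are illustration-only assumptions; the quoting rules themselves mirror the hunk: the first reasoning delta opens a markdown blockquote with "> 💭 ", and newlines inside later deltas are re-quoted so the whole thought stays inside the quote.

import re

def wrap_reasoning(delta_reasoning: str, started: bool) -> tuple[str, bool]:
    # Hypothetical helper mirroring the streaming logic in the hunk above.
    if not started:
        # The first reasoning chunk opens the blockquote with the thought marker.
        content = "> 💭 " + delta_reasoning
        started = True
    else:
        content = delta_reasoning
    if "\n" in content:
        # Re-quote any newline not already followed by ">" or another newline,
        # matching the re.sub call in the diff.
        content = re.sub(r"\n(?!(>|\n))", "\n> ", content)
    return content, started

# wrap_reasoning("Let me check.\nStep 1", False)
# returns ("> 💭 Let me check.\n> Step 1", True)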
@ -18,6 +18,22 @@ class ModelConfig(BaseModel):
|
||||
|
||||
|
||||
configs: dict[str, ModelConfig] = {
|
||||
"DeepSeek-R1-Distill-Qwen-32B": ModelConfig(
|
||||
properties=ModelProperties(context_size=64000, max_tokens=8192, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
),
|
||||
"DeepSeek-R1-Distill-Qwen-7B": ModelConfig(
|
||||
properties=ModelProperties(context_size=64000, max_tokens=8192, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
),
|
||||
"DeepSeek-R1": ModelConfig(
|
||||
properties=ModelProperties(context_size=64000, max_tokens=8192, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT],
|
||||
),
|
||||
"DeepSeek-V3": ModelConfig(
|
||||
properties=ModelProperties(context_size=64000, max_tokens=8192, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.TOOL_CALL, ModelFeature.STREAM_TOOL_CALL],
|
||||
),
|
||||
"Doubao-1.5-vision-pro-32k": ModelConfig(
|
||||
properties=ModelProperties(context_size=32768, max_tokens=12288, mode=LLMMode.CHAT),
|
||||
features=[ModelFeature.AGENT_THOUGHT, ModelFeature.VISION],
|
||||
|
||||
@ -118,6 +118,30 @@ model_credential_schema:
|
||||
type: select
|
||||
required: true
|
||||
options:
|
||||
- label:
|
||||
en_US: DeepSeek-R1-Distill-Qwen-32B
|
||||
value: DeepSeek-R1-Distill-Qwen-32B
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: DeepSeek-R1-Distill-Qwen-7B
|
||||
value: DeepSeek-R1-Distill-Qwen-7B
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: DeepSeek-R1
|
||||
value: DeepSeek-R1
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: DeepSeek-V3
|
||||
value: DeepSeek-V3
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: Doubao-1.5-vision-pro-32k
|
||||
value: Doubao-1.5-vision-pro-32k
|
||||
|
||||
@ -635,16 +635,13 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
handle stream chat generate response
|
||||
"""
|
||||
full_response = ""
|
||||
|
||||
for chunk in resp:
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0]
|
||||
|
||||
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ""):
|
||||
continue
|
||||
|
||||
delta_content = delta.delta.content or ""
|
||||
# check if there is a tool call in the response
|
||||
function_call = None
|
||||
tool_calls = []
|
||||
@ -657,9 +654,10 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
if function_call:
|
||||
assistant_message_tool_calls += [self._extract_response_function_call(function_call)]
|
||||
|
||||
delta_content = self._wrap_thinking_by_tag(delta_content)
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
|
||||
content=delta_content or "", tool_calls=assistant_message_tool_calls
|
||||
)
|
||||
|
||||
if delta.finish_reason is not None:
|
||||
@ -697,7 +695,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
|
||||
),
|
||||
)
|
||||
|
||||
full_response += delta.delta.content
|
||||
full_response += delta_content
|
||||
|
||||
def _handle_completion_generate_response(
|
||||
self,
|
||||
|
||||
@ -77,5 +77,4 @@
|
||||
- onebot
|
||||
- regex
|
||||
- trello
|
||||
- vanna
|
||||
- fal
|
||||
|
||||
@ -1,114 +0,0 @@
|
||||
"""
|
||||
Configuration classes for AWS Bedrock retrieve and generate API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextInferenceConfig:
|
||||
"""Text inference configuration"""
|
||||
|
||||
maxTokens: Optional[int] = None
|
||||
stopSequences: Optional[list[str]] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerformanceConfig:
|
||||
"""Performance configuration"""
|
||||
|
||||
latency: Literal["standard", "optimized"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptTemplate:
|
||||
"""Prompt template configuration"""
|
||||
|
||||
textPromptTemplate: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuardrailConfig:
|
||||
"""Guardrail configuration"""
|
||||
|
||||
guardrailId: str
|
||||
guardrailVersion: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationConfig:
|
||||
"""Generation configuration"""
|
||||
|
||||
additionalModelRequestFields: Optional[dict[str, Any]] = None
|
||||
guardrailConfiguration: Optional[GuardrailConfig] = None
|
||||
inferenceConfig: Optional[dict[str, TextInferenceConfig]] = None
|
||||
performanceConfig: Optional[PerformanceConfig] = None
|
||||
promptTemplate: Optional[PromptTemplate] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorSearchConfig:
|
||||
"""Vector search configuration"""
|
||||
|
||||
filter: Optional[dict[str, Any]] = None
|
||||
numberOfResults: Optional[int] = None
|
||||
overrideSearchType: Optional[Literal["HYBRID", "SEMANTIC"]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalConfig:
|
||||
"""Retrieval configuration"""
|
||||
|
||||
vectorSearchConfiguration: VectorSearchConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrchestrationConfig:
|
||||
"""Orchestration configuration"""
|
||||
|
||||
additionalModelRequestFields: Optional[dict[str, Any]] = None
|
||||
inferenceConfig: Optional[dict[str, TextInferenceConfig]] = None
|
||||
performanceConfig: Optional[PerformanceConfig] = None
|
||||
promptTemplate: Optional[PromptTemplate] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class KnowledgeBaseConfig:
|
||||
"""Knowledge base configuration"""
|
||||
|
||||
generationConfiguration: GenerationConfig
|
||||
knowledgeBaseId: str
|
||||
modelArn: str
|
||||
orchestrationConfiguration: Optional[OrchestrationConfig] = None
|
||||
retrievalConfiguration: Optional[RetrievalConfig] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionConfig:
|
||||
"""Session configuration"""
|
||||
|
||||
kmsKeyArn: Optional[str] = None
|
||||
sessionId: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrieveAndGenerateConfiguration:
|
||||
"""Retrieve and generate configuration
|
||||
The use of knowledgeBaseConfiguration or externalSourcesConfiguration depends on the type value
|
||||
"""
|
||||
|
||||
type: str = "KNOWLEDGE_BASE"
|
||||
knowledgeBaseConfiguration: Optional[KnowledgeBaseConfig] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrieveAndGenerateConfig:
|
||||
"""Retrieve and generate main configuration"""
|
||||
|
||||
input: dict[str, str]
|
||||
retrieveAndGenerateConfiguration: RetrieveAndGenerateConfiguration
|
||||
sessionConfiguration: Optional[SessionConfig] = None
|
||||
sessionId: Optional[str] = None
|
||||
@ -77,15 +77,27 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
line = 0
|
||||
try:
|
||||
line = 0
|
||||
# Initialize Bedrock client if not already initialized
|
||||
if not self.bedrock_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
|
||||
else:
|
||||
self.bedrock_client = boto3.client("bedrock-agent-runtime")
|
||||
aws_access_key_id = tool_parameters.get("aws_access_key_id")
|
||||
aws_secret_access_key = tool_parameters.get("aws_secret_access_key")
|
||||
|
||||
client_kwargs = {"service_name": "bedrock-agent-runtime", "region_name": aws_region or None}
|
||||
|
||||
# Only add credentials if both access key and secret key are provided
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
client_kwargs.update(
|
||||
{"aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key}
|
||||
)
|
||||
|
||||
self.bedrock_client = boto3.client(**client_kwargs)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")
|
||||
|
||||
try:
|
||||
line = 1
|
||||
if not self.knowledge_base_id:
|
||||
self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
|
||||
@ -123,7 +135,14 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)
|
||||
|
||||
line = 6
|
||||
return [self.create_json_message(res) for res in sorted_docs]
|
||||
result_type = tool_parameters.get("result_type")
|
||||
if result_type == "json":
|
||||
return [self.create_json_message(res) for res in sorted_docs]
|
||||
else:
|
||||
text = ""
|
||||
for i, res in enumerate(sorted_docs):
|
||||
text += f"{i + 1}: {res['content']}\n"
|
||||
return self.create_text_message(text)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception {str(e)}, line : {line}")
|
||||
@ -138,7 +157,6 @@ class BedrockRetrieveTool(BuiltinTool):
|
||||
if not parameters.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
# Optional: validate that the metadata filter, if provided, is a valid JSON string
|
||||
metadata_filter_str = parameters.get("metadata_filter")
|
||||
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
|
||||
raise ValueError("metadata_filter must be a valid JSON object")
|
||||
|
||||
@ -15,6 +15,60 @@ description:
|
||||
llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
|
||||
|
||||
parameters:
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS区域
|
||||
human_description:
|
||||
en_US: AWS region for the Bedrock service
|
||||
zh_Hans: Bedrock服务的AWS区域
|
||||
form: form
|
||||
|
||||
- name: aws_access_key_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Access Key ID
|
||||
zh_Hans: AWS访问密钥ID
|
||||
human_description:
|
||||
en_US: AWS access key ID for authentication (optional)
|
||||
zh_Hans: 用于身份验证的AWS访问密钥ID(可选)
|
||||
form: form
|
||||
|
||||
- name: aws_secret_access_key
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Secret Access Key
|
||||
zh_Hans: AWS秘密访问密钥
|
||||
human_description:
|
||||
en_US: AWS secret access key for authentication (optional)
|
||||
zh_Hans: 用于身份验证的AWS秘密访问密钥(可选)
|
||||
form: form
|
||||
|
||||
- name: result_type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: result type
|
||||
zh_Hans: 结果类型
|
||||
human_description:
|
||||
en_US: return a list of json or texts
|
||||
zh_Hans: 返回一个列表,内容是json还是纯文本
|
||||
default: text
|
||||
options:
|
||||
- value: json
|
||||
label:
|
||||
en_US: JSON
|
||||
zh_Hans: JSON
|
||||
- value: text
|
||||
label:
|
||||
en_US: Text
|
||||
zh_Hans: 文本
|
||||
form: form
|
||||
|
||||
- name: knowledge_base_id
|
||||
type: string
|
||||
required: true
|
||||
@ -95,6 +149,7 @@ parameters:
|
||||
zh_Hans: 重排模型ID
|
||||
pt_BR: rerank model id
|
||||
llm_description: rerank model id
|
||||
default: default
|
||||
options:
|
||||
- value: default
|
||||
label:
|
||||
@ -110,20 +165,6 @@ parameters:
|
||||
zh_Hans: amazon.rerank-v1:0
|
||||
form: form
|
||||
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
pt_BR: AWS Region
|
||||
human_description:
|
||||
en_US: AWS region where the Bedrock Knowledge Base is located
|
||||
zh_Hans: Bedrock知识库所在的AWS区域
|
||||
pt_BR: AWS region where the Bedrock Knowledge Base is located
|
||||
llm_description: AWS region where the Bedrock Knowledge Base is located
|
||||
form: form
|
||||
|
||||
- name: metadata_filter # Additional parameter for metadata filtering
|
||||
type: string # String type, expects JSON-formatted filter conditions
|
||||
required: false # Optional field - can be omitted
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
|
||||
@ -10,193 +10,63 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||
class BedrockRetrieveAndGenerateTool(BuiltinTool):
|
||||
bedrock_client: Any = None
|
||||
|
||||
def _create_text_inference_config(
|
||||
def _invoke(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
stop_sequences: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Create text inference configuration"""
|
||||
if any([max_tokens, stop_sequences, temperature, top_p]):
|
||||
config = {}
|
||||
if max_tokens is not None:
|
||||
config["maxTokens"] = max_tokens
|
||||
if stop_sequences:
|
||||
try:
|
||||
config["stopSequences"] = json.loads(stop_sequences)
|
||||
except json.JSONDecodeError:
|
||||
config["stopSequences"] = []
|
||||
if temperature is not None:
|
||||
config["temperature"] = temperature
|
||||
if top_p is not None:
|
||||
config["topP"] = top_p
|
||||
return config
|
||||
return None
|
||||
|
||||
def _create_guardrail_config(
|
||||
self,
|
||||
guardrail_id: Optional[str] = None,
|
||||
guardrail_version: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""Create guardrail configuration"""
|
||||
if guardrail_id and guardrail_version:
|
||||
return {"guardrailId": guardrail_id, "guardrailVersion": guardrail_version}
|
||||
return None
|
||||
|
||||
def _create_generation_config(
|
||||
self,
|
||||
additional_model_fields: Optional[str] = None,
|
||||
guardrail_config: Optional[dict] = None,
|
||||
text_inference_config: Optional[dict] = None,
|
||||
performance_mode: Optional[str] = None,
|
||||
prompt_template: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create generation configuration"""
|
||||
config = {}
|
||||
|
||||
if additional_model_fields:
|
||||
try:
|
||||
config["additionalModelRequestFields"] = json.loads(additional_model_fields)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if guardrail_config:
|
||||
config["guardrailConfiguration"] = guardrail_config
|
||||
|
||||
if text_inference_config:
|
||||
config["inferenceConfig"] = {"textInferenceConfig": text_inference_config}
|
||||
|
||||
if performance_mode:
|
||||
config["performanceConfig"] = {"latency": performance_mode}
|
||||
|
||||
if prompt_template:
|
||||
config["promptTemplate"] = {"textPromptTemplate": prompt_template}
|
||||
|
||||
return config
|
||||
|
||||
def _create_orchestration_config(
|
||||
self,
|
||||
orchestration_additional_model_fields: Optional[str] = None,
|
||||
orchestration_text_inference_config: Optional[dict] = None,
|
||||
orchestration_performance_mode: Optional[str] = None,
|
||||
orchestration_prompt_template: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Create orchestration configuration"""
|
||||
config = {}
|
||||
|
||||
if orchestration_additional_model_fields:
|
||||
try:
|
||||
config["additionalModelRequestFields"] = json.loads(orchestration_additional_model_fields)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if orchestration_text_inference_config:
|
||||
config["inferenceConfig"] = {"textInferenceConfig": orchestration_text_inference_config}
|
||||
|
||||
if orchestration_performance_mode:
|
||||
config["performanceConfig"] = {"latency": orchestration_performance_mode}
|
||||
|
||||
if orchestration_prompt_template:
|
||||
config["promptTemplate"] = {"textPromptTemplate": orchestration_prompt_template}
|
||||
|
||||
return config
|
||||
|
||||
def _create_vector_search_config(
|
||||
self,
|
||||
number_of_results: int = 5,
|
||||
search_type: str = "SEMANTIC",
|
||||
metadata_filter: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Create vector search configuration"""
|
||||
config = {
|
||||
"numberOfResults": number_of_results,
|
||||
"overrideSearchType": search_type,
|
||||
}
|
||||
|
||||
# Only add filter if metadata_filter is not empty
|
||||
if metadata_filter:
|
||||
config["filter"] = metadata_filter
|
||||
|
||||
return config
|
||||
|
||||
def _bedrock_retrieve_and_generate(
|
||||
self,
|
||||
query: str,
|
||||
knowledge_base_id: str,
|
||||
model_arn: str,
|
||||
# Generation Configuration
|
||||
additional_model_fields: Optional[str] = None,
|
||||
guardrail_id: Optional[str] = None,
|
||||
guardrail_version: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
stop_sequences: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
performance_mode: str = "standard",
|
||||
prompt_template: Optional[str] = None,
|
||||
# Orchestration Configuration
|
||||
orchestration_additional_model_fields: Optional[str] = None,
|
||||
orchestration_max_tokens: Optional[int] = None,
|
||||
orchestration_stop_sequences: Optional[str] = None,
|
||||
orchestration_temperature: Optional[float] = None,
|
||||
orchestration_top_p: Optional[float] = None,
|
||||
orchestration_performance_mode: Optional[str] = None,
|
||||
orchestration_prompt_template: Optional[str] = None,
|
||||
# Retrieval Configuration
|
||||
number_of_results: int = 5,
|
||||
search_type: str = "SEMANTIC",
|
||||
metadata_filter: Optional[dict] = None,
|
||||
# Additional Configuration
|
||||
session_id: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> ToolInvokeMessage:
|
||||
try:
|
||||
# Create text inference configurations
|
||||
text_inference_config = self._create_text_inference_config(max_tokens, stop_sequences, temperature, top_p)
|
||||
orchestration_text_inference_config = self._create_text_inference_config(
|
||||
orchestration_max_tokens, orchestration_stop_sequences, orchestration_temperature, orchestration_top_p
|
||||
)
|
||||
# Initialize Bedrock client if not already initialized
|
||||
if not self.bedrock_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
aws_access_key_id = tool_parameters.get("aws_access_key_id")
|
||||
aws_secret_access_key = tool_parameters.get("aws_secret_access_key")
|
||||
|
||||
# Create guardrail configuration
|
||||
guardrail_config = self._create_guardrail_config(guardrail_id, guardrail_version)
|
||||
client_kwargs = {"service_name": "bedrock-agent-runtime", "region_name": aws_region or None}
|
||||
|
||||
# Create vector search configuration
|
||||
vector_search_config = self._create_vector_search_config(number_of_results, search_type, metadata_filter)
|
||||
# Only add credentials if both access key and secret key are provided
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
client_kwargs.update(
|
||||
{"aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key}
|
||||
)
|
||||
|
||||
# Create generation configuration
|
||||
generation_config = self._create_generation_config(
|
||||
additional_model_fields, guardrail_config, text_inference_config, performance_mode, prompt_template
|
||||
)
|
||||
self.bedrock_client = boto3.client(**client_kwargs)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")
|
||||
|
||||
# Create orchestration configuration
|
||||
orchestration_config = self._create_orchestration_config(
|
||||
orchestration_additional_model_fields,
|
||||
orchestration_text_inference_config,
|
||||
orchestration_performance_mode,
|
||||
orchestration_prompt_template,
|
||||
)
|
||||
try:
|
||||
request_config = {}
|
||||
|
||||
# Create knowledge base configuration
|
||||
knowledge_base_config = {
|
||||
"knowledgeBaseId": knowledge_base_id,
|
||||
"modelArn": model_arn,
|
||||
"generationConfiguration": generation_config,
|
||||
"orchestrationConfiguration": orchestration_config,
|
||||
"retrievalConfiguration": {"vectorSearchConfiguration": vector_search_config},
|
||||
}
|
||||
# Set input configuration
|
||||
input_text = tool_parameters.get("input")
|
||||
if input_text:
|
||||
request_config["input"] = {"text": input_text}
|
||||
|
||||
# Create request configuration
|
||||
request_config = {
|
||||
"input": {"text": query},
|
||||
"retrieveAndGenerateConfiguration": {
|
||||
"type": "KNOWLEDGE_BASE",
|
||||
"knowledgeBaseConfiguration": knowledge_base_config,
|
||||
},
|
||||
}
|
||||
# Build retrieve and generate configuration
|
||||
config_type = tool_parameters.get("type")
|
||||
retrieve_generate_config = {"type": config_type}
|
||||
|
||||
# Add session configuration if provided
|
||||
if session_id and len(session_id) >= 2:
|
||||
request_config["sessionConfiguration"] = {"sessionId": session_id}
|
||||
# Add configuration based on type
|
||||
if config_type == "KNOWLEDGE_BASE":
|
||||
kb_config_str = tool_parameters.get("knowledge_base_configuration")
|
||||
kb_config = json.loads(kb_config_str) if kb_config_str else None
|
||||
retrieve_generate_config["knowledgeBaseConfiguration"] = kb_config
|
||||
else: # EXTERNAL_SOURCES
|
||||
es_config_str = tool_parameters.get("external_sources_configuration")
|
||||
es_config = json.loads(es_config_str) if es_config_str else None
|
||||
retrieve_generate_config["externalSourcesConfiguration"] = es_config
|
||||
|
||||
request_config["retrieveAndGenerateConfiguration"] = retrieve_generate_config
|
||||
|
||||
# Parse session configuration
|
||||
session_config_str = tool_parameters.get("session_configuration")
|
||||
session_config = json.loads(session_config_str) if session_config_str else None
|
||||
if session_config:
|
||||
request_config["sessionConfiguration"] = session_config
|
||||
|
||||
# Add session ID if provided
|
||||
session_id = tool_parameters.get("session_id")
|
||||
if session_id:
|
||||
request_config["sessionId"] = session_id
|
||||
|
||||
# Send request
|
||||
@ -226,99 +96,42 @@ class BedrockRetrieveAndGenerateTool(BuiltinTool):
|
||||
citation_info["references"].append(reference)
|
||||
|
||||
result["citations"].append(citation_info)
|
||||
|
||||
return result
|
||||
|
||||
result_type = tool_parameters.get("result_type")
|
||||
if result_type == "json":
|
||||
return self.create_json_message(result)
|
||||
elif result_type == "text-with-citations":
|
||||
return self.create_text_message(result)
|
||||
else:
|
||||
return self.create_text_message(result.get("output"))
|
||||
except json.JSONDecodeError as e:
|
||||
return self.create_text_message(f"Invalid JSON format: {str(e)}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Error calling Bedrock service: {str(e)}")
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> ToolInvokeMessage:
|
||||
try:
|
||||
# Initialize Bedrock client if not already initialized
|
||||
if not self.bedrock_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
aws_access_key_id = tool_parameters.get("aws_access_key_id")
|
||||
aws_secret_access_key = tool_parameters.get("aws_secret_access_key")
|
||||
|
||||
client_kwargs = {
|
||||
"service_name": "bedrock-agent-runtime",
|
||||
}
|
||||
if aws_region:
|
||||
client_kwargs["region_name"] = aws_region
|
||||
# Only add credentials if both access key and secret key are provided
|
||||
if aws_access_key_id and aws_secret_access_key:
|
||||
client_kwargs.update(
|
||||
{"aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key}
|
||||
)
|
||||
|
||||
try:
|
||||
self.bedrock_client = boto3.client(**client_kwargs)
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")
|
||||
|
||||
# Parse metadata filter if provided
|
||||
metadata_filter = None
|
||||
if metadata_filter_str := tool_parameters.get("metadata_filter"):
|
||||
try:
|
||||
parsed_filter = json.loads(metadata_filter_str)
|
||||
if parsed_filter: # Only set if not empty
|
||||
metadata_filter = parsed_filter
|
||||
except json.JSONDecodeError:
|
||||
return self.create_text_message("metadata_filter must be a valid JSON string")
|
||||
|
||||
try:
|
||||
response = self._bedrock_retrieve_and_generate(
|
||||
query=tool_parameters["query"],
|
||||
knowledge_base_id=tool_parameters["knowledge_base_id"],
|
||||
model_arn=tool_parameters["model_arn"],
|
||||
# Generation Configuration
|
||||
additional_model_fields=tool_parameters.get("additional_model_fields"),
|
||||
guardrail_id=tool_parameters.get("guardrail_id"),
|
||||
guardrail_version=tool_parameters.get("guardrail_version"),
|
||||
max_tokens=tool_parameters.get("max_tokens"),
|
||||
stop_sequences=tool_parameters.get("stop_sequences"),
|
||||
temperature=tool_parameters.get("temperature"),
|
||||
top_p=tool_parameters.get("top_p"),
|
||||
performance_mode=tool_parameters.get("performance_mode", "standard"),
|
||||
prompt_template=tool_parameters.get("prompt_template"),
|
||||
# Orchestration Configuration
|
||||
orchestration_additional_model_fields=tool_parameters.get("orchestration_additional_model_fields"),
|
||||
orchestration_max_tokens=tool_parameters.get("orchestration_max_tokens"),
|
||||
orchestration_stop_sequences=tool_parameters.get("orchestration_stop_sequences"),
|
||||
orchestration_temperature=tool_parameters.get("orchestration_temperature"),
|
||||
orchestration_top_p=tool_parameters.get("orchestration_top_p"),
|
||||
orchestration_performance_mode=tool_parameters.get("orchestration_performance_mode"),
|
||||
orchestration_prompt_template=tool_parameters.get("orchestration_prompt_template"),
|
||||
# Retrieval Configuration
|
||||
number_of_results=tool_parameters.get("number_of_results", 5),
|
||||
search_type=tool_parameters.get("search_type", "SEMANTIC"),
|
||||
metadata_filter=metadata_filter,
|
||||
# Additional Configuration
|
||||
session_id=tool_parameters.get("session_id"),
|
||||
)
|
||||
return self.create_json_message(response)
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Tool invocation error: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Tool execution error: {str(e)}")
|
||||
return self.create_text_message(f"Tool invocation error: {str(e)}")
|
||||
|
||||
def validate_parameters(self, parameters: dict[str, Any]) -> None:
|
||||
"""Validate the parameters"""
|
||||
required_params = ["query", "model_arn", "knowledge_base_id"]
|
||||
for param in required_params:
|
||||
if not parameters.get(param):
|
||||
raise ValueError(f"{param} is required")
|
||||
# Validate required parameters
|
||||
if not parameters.get("input"):
|
||||
raise ValueError("input is required")
|
||||
if not parameters.get("type"):
|
||||
raise ValueError("type is required")
|
||||
|
||||
# Validate metadata filter if provided
|
||||
if metadata_filter_str := parameters.get("metadata_filter"):
|
||||
try:
|
||||
if not isinstance(json.loads(metadata_filter_str), dict):
|
||||
raise ValueError("metadata_filter must be a valid JSON object")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("metadata_filter must be a valid JSON string")
|
||||
# Validate JSON configurations
|
||||
json_configs = ["knowledge_base_configuration", "external_sources_configuration", "session_configuration"]
|
||||
for config in json_configs:
|
||||
if config_value := parameters.get(config):
|
||||
try:
|
||||
json.loads(config_value)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"{config} must be a valid JSON string")
|
||||
|
||||
# Validate configuration type
|
||||
config_type = parameters.get("type")
|
||||
if config_type not in ["KNOWLEDGE_BASE", "EXTERNAL_SOURCES"]:
|
||||
raise ValueError("type must be either KNOWLEDGE_BASE or EXTERNAL_SOURCES")
|
||||
|
||||
# Validate type-specific configuration
|
||||
if config_type == "KNOWLEDGE_BASE" and not parameters.get("knowledge_base_configuration"):
|
||||
raise ValueError("knowledge_base_configuration is required when type is KNOWLEDGE_BASE")
|
||||
elif config_type == "EXTERNAL_SOURCES" and not parameters.get("external_sources_configuration"):
|
||||
raise ValueError("external_sources_configuration is required when type is EXTERNAL_SOURCES")
|
||||
|
||||
@ -8,24 +8,11 @@ identity:
|
||||
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for retrieving and generating information using Amazon Bedrock Knowledge Base
|
||||
zh_Hans: 使用Amazon Bedrock知识库进行信息检索和生成的工具
|
||||
en_US: "This is an advanced usage of Bedrock Retrieve. Please refer to the API documentation for detailed parameters and paste them into the corresponding Knowledge Base Configuration or External Sources Configuration"
|
||||
zh_Hans: "这个工具为Bedrock Retrieve的高级用法,请参考API设置详细的参数,并粘贴到对应的知识库配置或者外部源配置"
|
||||
llm: A tool for retrieving and generating information using Amazon Bedrock Knowledge Base
|
||||
|
||||
parameters:
|
||||
# Additional Configuration
|
||||
- name: session_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Session ID
|
||||
zh_Hans: 会话ID
|
||||
human_description:
|
||||
en_US: Optional session ID for continuous conversations
|
||||
zh_Hans: 用于连续对话的可选会话ID
|
||||
form: form
|
||||
|
||||
# AWS Configuration
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
@ -59,300 +46,103 @@ parameters:
|
||||
zh_Hans: 用于身份验证的AWS秘密访问密钥(可选)
|
||||
form: form
|
||||
|
||||
# Knowledge Base Configuration
|
||||
- name: knowledge_base_id
|
||||
type: string
|
||||
- name: result_type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: Knowledge Base ID
|
||||
zh_Hans: 知识库ID
|
||||
en_US: result type
|
||||
zh_Hans: 结果类型
|
||||
human_description:
|
||||
en_US: ID of the Bedrock Knowledge Base
|
||||
zh_Hans: Bedrock知识库的ID
|
||||
en_US: return a list of json or texts
|
||||
zh_Hans: 返回一个列表,内容是json还是纯文本
|
||||
default: text
|
||||
options:
|
||||
- value: json
|
||||
label:
|
||||
en_US: JSON
|
||||
zh_Hans: JSON
|
||||
- value: text
|
||||
label:
|
||||
en_US: Text
|
||||
zh_Hans: 文本
|
||||
- value: text-with-citations
|
||||
label:
|
||||
en_US: Text With Citations
|
||||
zh_Hans: 文本(包含引用)
|
||||
form: form
|
||||
|
||||
- name: model_arn
|
||||
- name: input
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Model ARN
|
||||
zh_Hans: 模型ARN
|
||||
en_US: Input Text
|
||||
zh_Hans: 输入文本
|
||||
human_description:
|
||||
en_US: The ARN of the model to use
|
||||
zh_Hans: 要使用的模型ARN
|
||||
form: form
|
||||
|
||||
# Retrieval Configuration
|
||||
- name: query
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Query
|
||||
zh_Hans: 查询
|
||||
human_description:
|
||||
en_US: The search query to retrieve information
|
||||
zh_Hans: 用于检索信息的查询语句
|
||||
en_US: The text query to retrieve information
|
||||
zh_Hans: 用于检索信息的文本查询
|
||||
form: llm
|
||||
|
||||
- name: number_of_results
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Number of Results
|
||||
zh_Hans: 结果数量
|
||||
human_description:
|
||||
en_US: Number of results to retrieve (1-10)
|
||||
zh_Hans: 要检索的结果数量(1-10)
|
||||
default: 5
|
||||
min: 1
|
||||
max: 10
|
||||
form: form
|
||||
|
||||
- name: search_type
|
||||
- name: type
|
||||
type: select
|
||||
required: false
|
||||
required: true
|
||||
label:
|
||||
en_US: Search Type
|
||||
zh_Hans: 搜索类型
|
||||
en_US: Configuration Type
|
||||
zh_Hans: 配置类型
|
||||
human_description:
|
||||
en_US: Type of search to perform
|
||||
zh_Hans: 要执行的搜索类型
|
||||
default: SEMANTIC
|
||||
en_US: Type of retrieve and generate configuration
|
||||
zh_Hans: 检索和生成配置的类型
|
||||
options:
|
||||
- value: SEMANTIC
|
||||
- value: KNOWLEDGE_BASE
|
||||
label:
|
||||
en_US: Semantic Search
|
||||
zh_Hans: 语义搜索
|
||||
- value: HYBRID
|
||||
en_US: Knowledge Base
|
||||
zh_Hans: 知识库
|
||||
- value: EXTERNAL_SOURCES
|
||||
label:
|
||||
en_US: Hybrid Search
|
||||
zh_Hans: 混合搜索
|
||||
en_US: External Sources
|
||||
zh_Hans: 外部源
|
||||
form: form
|
||||
|
||||
- name: metadata_filter
|
||||
- name: knowledge_base_configuration
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Metadata Filter
|
||||
zh_Hans: 元数据过滤器
|
||||
en_US: Knowledge Base Configuration
|
||||
zh_Hans: 知识库配置
|
||||
human_description:
|
||||
en_US: JSON formatted filter conditions for metadata, supporting operations like equals, greaterThan, lessThan, etc.
|
||||
zh_Hans: 元数据的JSON格式过滤条件,支持等于、大于、小于等操作
|
||||
default: "{}"
|
||||
en_US: Please refer to https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate for complete parameters and paste them here
|
||||
zh_Hans: 请参考 https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate 配置完整的参数并粘贴到这里
|
||||
form: form
|
||||
|
||||
# Generation Configuration
|
||||
- name: guardrail_id
|
||||
- name: external_sources_configuration
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Guardrail ID
|
||||
zh_Hans: 防护栏ID
|
||||
en_US: External Sources Configuration
|
||||
zh_Hans: 外部源配置
|
||||
human_description:
|
||||
en_US: ID of the guardrail to apply
|
||||
zh_Hans: 要应用的防护栏ID
|
||||
en_US: Please refer to https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate for complete parameters and paste them here
|
||||
zh_Hans: 请参考 https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html#retrieve-and-generate 配置完整的参数并粘贴到这里
|
||||
form: form
|
||||
|
||||
- name: guardrail_version
|
||||
- name: session_configuration
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Guardrail Version
|
||||
zh_Hans: 防护栏版本
|
||||
en_US: Session Configuration
|
||||
zh_Hans: 会话配置
|
||||
human_description:
|
||||
en_US: Version of the guardrail to apply
|
||||
zh_Hans: 要应用的防护栏版本
|
||||
en_US: JSON formatted session configuration
|
||||
zh_Hans: JSON格式的会话配置
|
||||
default: ""
|
||||
form: form
|
||||
|
||||
- name: max_tokens
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Maximum Tokens
|
||||
zh_Hans: 最大令牌数
|
||||
human_description:
|
||||
en_US: Maximum number of tokens to generate
|
||||
zh_Hans: 生成的最大令牌数
|
||||
default: 2048
|
||||
form: form
|
||||
|
||||
- name: stop_sequences
|
||||
- name: session_id
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Stop Sequences
|
||||
zh_Hans: 停止序列
|
||||
en_US: Session ID
|
||||
zh_Hans: 会话ID
|
||||
human_description:
|
||||
en_US: JSON array of strings that will stop generation when encountered
|
||||
zh_Hans: JSON数组格式的字符串,遇到这些序列时将停止生成
|
||||
default: "[]"
|
||||
form: form
|
||||
|
||||
- name: temperature
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Temperature
|
||||
zh_Hans: 温度
|
||||
human_description:
|
||||
en_US: Controls randomness in the output (0-1)
|
||||
zh_Hans: 控制输出的随机性(0-1)
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
form: form
|
||||
|
||||
- name: top_p
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Top P
|
||||
zh_Hans: Top P值
|
||||
human_description:
|
||||
en_US: Controls diversity via nucleus sampling (0-1)
|
||||
zh_Hans: 通过核采样控制多样性(0-1)
|
||||
default: 0.95
|
||||
min: 0
|
||||
max: 1
|
||||
form: form
|
||||
|
||||
- name: performance_mode
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: Performance Mode
|
||||
zh_Hans: 性能模式
|
||||
human_description:
|
||||
en_US: Select performance optimization mode(performanceConfig.latency)
|
||||
zh_Hans: 选择性能优化模式(performanceConfig.latency)
|
||||
default: standard
|
||||
options:
|
||||
- value: standard
|
||||
label:
|
||||
en_US: Standard
|
||||
zh_Hans: 标准
|
||||
- value: optimized
|
||||
label:
|
||||
en_US: Optimized
|
||||
zh_Hans: 优化
|
||||
form: form
|
||||
|
||||
- name: prompt_template
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Prompt Template
|
||||
zh_Hans: 提示模板
|
||||
human_description:
|
||||
en_US: Custom prompt template for generation
|
||||
zh_Hans: 用于生成的自定义提示模板
|
||||
form: form
|
||||
|
||||
- name: additional_model_fields
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Additional Model Fields
|
||||
zh_Hans: 额外模型字段
|
||||
human_description:
|
||||
en_US: JSON formatted additional fields for model configuration
|
||||
zh_Hans: JSON格式的额外模型配置字段
|
||||
default: "{}"
|
||||
form: form
|
||||
|
||||
# Orchestration Configuration
|
||||
- name: orchestration_max_tokens
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Maximum Tokens
|
||||
zh_Hans: 编排最大令牌数
|
||||
human_description:
|
||||
en_US: Maximum number of tokens for orchestration
|
||||
zh_Hans: 编排过程的最大令牌数
|
||||
default: 2048
|
||||
form: form
|
||||
|
||||
- name: orchestration_stop_sequences
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Stop Sequences
|
||||
zh_Hans: 编排停止序列
|
||||
human_description:
|
||||
en_US: JSON array of strings that will stop orchestration when encountered
|
||||
zh_Hans: JSON数组格式的字符串,遇到这些序列时将停止编排
|
||||
default: "[]"
|
||||
form: form
|
||||
|
||||
- name: orchestration_temperature
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Temperature
|
||||
zh_Hans: 编排温度
|
||||
human_description:
|
||||
en_US: Controls randomness in the orchestration output (0-1)
|
||||
zh_Hans: 控制编排输出的随机性(0-1)
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
form: form
|
||||
|
||||
- name: orchestration_top_p
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Top P
|
||||
zh_Hans: 编排Top P值
|
||||
human_description:
|
||||
en_US: Controls diversity via nucleus sampling in orchestration (0-1)
|
||||
zh_Hans: 通过核采样控制编排的多样性(0-1)
|
||||
default: 0.95
|
||||
min: 0
|
||||
max: 1
|
||||
form: form
|
||||
|
||||
- name: orchestration_performance_mode
|
||||
type: select
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Performance Mode
|
||||
zh_Hans: 编排性能模式
|
||||
human_description:
|
||||
en_US: Select performance optimization mode for orchestration
|
||||
zh_Hans: 选择编排的性能优化模式
|
||||
default: standard
|
||||
options:
|
||||
- value: standard
|
||||
label:
|
||||
en_US: Standard
|
||||
zh_Hans: 标准
|
||||
- value: optimized
|
||||
label:
|
||||
en_US: Optimized
|
||||
zh_Hans: 优化
|
||||
form: form
|
||||
|
||||
- name: orchestration_prompt_template
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Prompt Template
|
||||
zh_Hans: 编排提示模板
|
||||
human_description:
|
||||
en_US: Custom prompt template for orchestration
|
||||
zh_Hans: 用于编排的自定义提示模板
|
||||
form: form
|
||||
|
||||
- name: orchestration_additional_model_fields
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Orchestration Additional Model Fields
|
||||
zh_Hans: 编排额外模型字段
|
||||
human_description:
|
||||
en_US: JSON formatted additional fields for orchestration model configuration
|
||||
zh_Hans: JSON格式的编排模型额外配置字段
|
||||
default: "{}"
|
||||
en_US: Session ID for continuous conversations
|
||||
zh_Hans: 用于连续对话的会话ID
|
||||
form: form
|
||||
|
||||
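For reference, a hedged sketch of the kind of JSON that could be pasted into the Knowledge Base Configuration field above. The knowledge base ID and model ARN are placeholders, and the key names follow the knowledgeBaseConfiguration shape of the boto3 retrieve_and_generate API linked in the parameter description (the same fields as the KnowledgeBaseConfig and VectorSearchConfig dataclasses removed earlier in this diff).

import json

# Placeholder values; substitute your own knowledge base ID and model ARN.
knowledge_base_configuration = {
    "knowledgeBaseId": "EXAMPLEKBID",
    "modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
    "retrievalConfiguration": {
        "vectorSearchConfiguration": {
            "numberOfResults": 5,
            "overrideSearchType": "SEMANTIC",
        }
    },
}

# The tool reads this field as a JSON string, so serialize it before pasting.
print(json.dumps(knowledge_base_configuration))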
Binary file not shown (before: 4.5 KiB).
@ -1,134 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from vanna.remote import VannaDefault # type: ignore
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class VannaTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
invoke tools
|
||||
"""
|
||||
# Ensure runtime and credentials
|
||||
if not self.runtime or not self.runtime.credentials:
|
||||
raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing")
|
||||
api_key = self.runtime.credentials.get("api_key", None)
|
||||
if not api_key:
|
||||
raise ToolProviderCredentialValidationError("Please input api key")
|
||||
|
||||
model = tool_parameters.get("model", "")
|
||||
if not model:
|
||||
return self.create_text_message("Please input RAG model")
|
||||
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
if not prompt:
|
||||
return self.create_text_message("Please input prompt")
|
||||
|
||||
url = tool_parameters.get("url", "")
|
||||
if not url:
|
||||
return self.create_text_message("Please input URL/Host/DSN")
|
||||
|
||||
db_name = tool_parameters.get("db_name", "")
|
||||
username = tool_parameters.get("username", "")
|
||||
password = tool_parameters.get("password", "")
|
||||
port = tool_parameters.get("port", 0)
|
||||
|
||||
base_url = self.runtime.credentials.get("base_url", None)
|
||||
vn = VannaDefault(model=model, api_key=api_key, config={"endpoint": base_url})
|
||||
|
||||
db_type = tool_parameters.get("db_type", "")
|
||||
if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}:
|
||||
if not db_name:
|
||||
return self.create_text_message("Please input database name")
|
||||
if not username:
|
||||
return self.create_text_message("Please input username")
|
||||
if port < 1:
|
||||
return self.create_text_message("Please input port")
|
||||
|
||||
schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS"
|
||||
match db_type:
|
||||
case "SQLite":
|
||||
schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null"
|
||||
vn.connect_to_sqlite(url)
|
||||
case "Postgres":
|
||||
vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||
case "DuckDB":
|
||||
vn.connect_to_duckdb(url=url)
|
||||
case "SQLServer":
|
||||
vn.connect_to_mssql(url)
|
||||
case "MySQL":
|
||||
vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||
case "Oracle":
|
||||
vn.connect_to_oracle(user=username, password=password, dsn=url)
|
||||
case "Hive":
|
||||
vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||
case "ClickHouse":
|
||||
vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||
|
||||
enable_training = tool_parameters.get("enable_training", False)
|
||||
reset_training_data = tool_parameters.get("reset_training_data", False)
|
||||
if enable_training:
|
||||
if reset_training_data:
|
||||
existing_training_data = vn.get_training_data()
|
||||
if len(existing_training_data) > 0:
|
||||
for _, training_data in existing_training_data.iterrows():
|
||||
vn.remove_training_data(training_data["id"])
|
||||
|
||||
ddl = tool_parameters.get("ddl", "")
|
||||
question = tool_parameters.get("question", "")
|
||||
sql = tool_parameters.get("sql", "")
|
||||
memos = tool_parameters.get("memos", "")
|
||||
training_metadata = tool_parameters.get("training_metadata", False)
|
||||
|
||||
if training_metadata:
|
||||
if db_type == "SQLite":
|
||||
df_ddl = vn.run_sql(schema_sql)
|
||||
for ddl in df_ddl["sql"].to_list():
|
||||
vn.train(ddl=ddl)
|
||||
else:
|
||||
df_information_schema = vn.run_sql(schema_sql)
|
||||
plan = vn.get_training_plan_generic(df_information_schema)
|
||||
vn.train(plan=plan)
|
||||
|
||||
if ddl:
|
||||
vn.train(ddl=ddl)
|
||||
|
||||
if sql:
|
||||
if question:
|
||||
vn.train(question=question, sql=sql)
|
||||
else:
|
||||
vn.train(sql=sql)
|
||||
if memos:
|
||||
vn.train(documentation=memos)
|
||||
|
||||
#########################################################################################
|
||||
# Due to CVE-2024-5565, we have to disable the chart generation feature
|
||||
# The Vanna library uses a prompt function to present the user with visualized results,
|
||||
# it is possible to alter the prompt using prompt injection and run arbitrary Python code
|
||||
# instead of the intended visualization code.
|
||||
# Specifically - allowing external input to the library’s “ask” method
|
||||
# with "visualize" set to True (default behavior) leads to remote code execution.
|
||||
# Affected versions: <= 0.5.5
|
||||
#########################################################################################
|
||||
allow_llm_to_see_data = tool_parameters.get("allow_llm_to_see_data", False)
|
||||
res = vn.ask(
|
||||
prompt, print_results=False, auto_train=True, visualize=False, allow_llm_to_see_data=allow_llm_to_see_data
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
if res is not None:
|
||||
result.append(self.create_text_message(res[0]))
|
||||
if len(res) > 1 and res[1] is not None:
|
||||
result.append(self.create_text_message(res[1].to_markdown()))
|
||||
if len(res) > 2 and res[2] is not None:
|
||||
result.append(
|
||||
self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"})
|
||||
)
|
||||
|
||||
return result
|
||||
Some files were not shown because too many files have changed in this diff.