mirror of
https://github.com/langgenius/dify.git
synced 2026-01-30 00:36:12 +08:00
Compare commits
93 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8835435558 | |||
| a80d8286c2 | |||
| 6b15827246 | |||
| 41d0a8b295 | |||
| d0e1ea8f06 | |||
| f3b9647bb4 | |||
| 9de67c586f | |||
| 92f594f5e7 | |||
| 06d5273217 | |||
| 94d7babbf1 | |||
| 306216dbe5 | |||
| ab2e20ee0a | |||
| 146e95d88f | |||
| d7ae86799c | |||
| 7b26c9e2ef | |||
| 6bcafdbc87 | |||
| 059c089f93 | |||
| c1e7193c4b | |||
| 2423563d45 | |||
| 260672986e | |||
| 5d48406d64 | |||
| 2b2dbabc11 | |||
| 13b64bc55a | |||
| 279f099ba0 | |||
| 32747641e4 | |||
| db43ed6f41 | |||
| 7699621983 | |||
| 4dfbcd0b4e | |||
| a9ee18300e | |||
| b4861d2b5c | |||
| 913f2b84a6 | |||
| cc89933d8f | |||
| a14ea6582d | |||
| 076f3289d2 | |||
| 518083dfe0 | |||
| 2b366bb321 | |||
| 292d4c077a | |||
| fc4c03640d | |||
| 985253197f | |||
| 48b4249790 | |||
| fb64fcb271 | |||
| 41e452dcc5 | |||
| d218c66e25 | |||
| e173b1cb2a | |||
| 9b598db559 | |||
| e122d677ad | |||
| 4c63cbf5b1 | |||
| 288705fefd | |||
| 8c4ae98f3d | |||
| 08aa367892 | |||
| ff527a0190 | |||
| 6e05f8ca93 | |||
| 6309d070d1 | |||
| fe14130b3c | |||
| 52ebffa857 | |||
| d14f15863d | |||
| 7c9b585a47 | |||
| c039f4af83 | |||
| 07285e5f8b | |||
| 16d80ebab3 | |||
| 61e816f24c | |||
| 2feb16d957 | |||
| 3043fbe73b | |||
| 9f99c3f55b | |||
| a07a6d8c26 | |||
| 695841a3cf | |||
| 3efaa713da | |||
| 9822f687f7 | |||
| b9d83c04bc | |||
| 298ad6782d | |||
| f4be2b8bcd | |||
| e83e239faf | |||
| 62bf7f0fc2 | |||
| 7dea485d57 | |||
| 5b9858a8a3 | |||
| 42a5b3ec17 | |||
| 2d1cb076c6 | |||
| 289c93d081 | |||
| c0fe706597 | |||
| 9cba1c8bf4 | |||
| cbf095465c | |||
| c007dbdc13 | |||
| ff493d017b | |||
| 7f6ad9653e | |||
| 2851a9f04e | |||
| c536f85b2e | |||
| b1352ff8b7 | |||
| cc63c8499f | |||
| f191b8b8d1 | |||
| 5003db987d | |||
| 07aab5e868 | |||
| 875dfbbf0e | |||
| 9e7efa45d4 |
2
.github/workflows/build-api-image.yml
vendored
2
.github/workflows/build-api-image.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
with:
|
||||
images: langgenius/dify-api
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=semver,pattern={{major}}.{{minor}}.{{patch}}
|
||||
|
||||
2
.github/workflows/build-web-image.yml
vendored
2
.github/workflows/build-web-image.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
with:
|
||||
images: langgenius/dify-web
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=semver,pattern={{major}}.{{minor}}.{{patch}}
|
||||
|
||||
37
.github/workflows/check_no_chinese_comments.py
vendored
37
.github/workflows/check_no_chinese_comments.py
vendored
@ -1,37 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
from zhon.hanzi import punctuation
|
||||
|
||||
def has_chinese_characters(text):
|
||||
for char in text:
|
||||
if '\u4e00' <= char <= '\u9fff' or char in punctuation:
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_file_for_chinese_comments(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
for line_number, line in enumerate(file, start=1):
|
||||
if has_chinese_characters(line):
|
||||
print(f"Found Chinese characters in {file_path} on line {line_number}:")
|
||||
print(line.strip())
|
||||
return True
|
||||
return False
|
||||
|
||||
def main():
|
||||
has_chinese = False
|
||||
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
|
||||
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
|
||||
'prompts.py']
|
||||
|
||||
for root, _, files in os.walk("."):
|
||||
for file in files:
|
||||
if file.endswith(".py") and file not in excluded_files:
|
||||
file_path = os.path.join(root, file)
|
||||
if check_file_for_chinese_comments(file_path):
|
||||
has_chinese = True
|
||||
|
||||
if has_chinese:
|
||||
raise Exception("Found Chinese characters in Python files. Please remove them.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
31
.github/workflows/check_no_chinese_comments.yml
vendored
31
.github/workflows/check_no_chinese_comments.yml
vendored
@ -1,31 +0,0 @@
|
||||
name: Check for Chinese comments
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
check-chinese-comments:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.9
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install zhon
|
||||
|
||||
- name: Run script to check for Chinese comments
|
||||
run: |
|
||||
python .github/workflows/check_no_chinese_comments.py
|
||||
@ -37,7 +37,6 @@ https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e19
|
||||
|
||||
|
||||
We provide the following free resources for registered Dify cloud users (sign up at [dify.ai](https://dify.ai)):
|
||||
* 600,000 free Claude model tokens to build Claude-powered apps
|
||||
* 200 free OpenAI queries to build OpenAI-based apps
|
||||
|
||||
|
||||
|
||||
@ -36,7 +36,6 @@ https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e19
|
||||
|
||||
|
||||
我们为所有注册云端版的用户免费提供以下资源(登录 [dify.ai](https://cloud.dify.ai) 即可使用):
|
||||
* 60 万 Tokens Claude 模型的消息调用额度,用于创建基于 Claude 模型的 AI 应用
|
||||
* 200 次 OpenAI 模型的消息调用额度,用于创建基于 OpenAI 模型的 AI 应用
|
||||
* 300 万 讯飞星火大模型 Token 的调用额度,用于创建基于讯飞星火大模型的 AI 应用
|
||||
* 100 万 MiniMax Token 的调用额度,用于创建基于 MiniMax 模型的 AI 应用
|
||||
|
||||
@ -18,6 +18,9 @@ SERVICE_API_URL=http://127.0.0.1:5001
|
||||
APP_API_URL=http://127.0.0.1:5001
|
||||
APP_WEB_URL=http://127.0.0.1:3000
|
||||
|
||||
# Files URL
|
||||
FILES_URL=http://127.0.0.1:5001
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
|
||||
@ -50,7 +53,7 @@ S3_REGION=your-region
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# Weaviate configuration
|
||||
@ -63,6 +66,21 @@ WEAVIATE_BATCH_SIZE=100
|
||||
QDRANT_URL=http://localhost:6333
|
||||
QDRANT_API_KEY=difyai123456
|
||||
|
||||
# Milvus configuration
|
||||
MILVUS_HOST=127.0.0.1
|
||||
MILVUS_PORT=19530
|
||||
MILVUS_USER=root
|
||||
MILVUS_PASSWORD=Milvus
|
||||
MILVUS_SECURE=false
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
|
||||
# Model Configuration
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT=base64
|
||||
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE=
|
||||
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
"request": "launch",
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "api/app.py",
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_DEBUG": "1",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
13
api/app.py
13
api/app.py
@ -6,6 +6,9 @@ from werkzeug.exceptions import Unauthorized
|
||||
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
|
||||
from gevent import monkey
|
||||
monkey.patch_all()
|
||||
if os.environ.get("VECTOR_STORE") == 'milvus':
|
||||
import grpc.experimental.gevent
|
||||
grpc.experimental.gevent.init_gevent()
|
||||
|
||||
import logging
|
||||
import json
|
||||
@ -16,7 +19,7 @@ from flask_cors import CORS
|
||||
|
||||
from core.model_providers.providers import hosted
|
||||
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
|
||||
ext_database, ext_storage, ext_mail, ext_stripe
|
||||
ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
|
||||
@ -76,6 +79,7 @@ def create_app(test_config=None) -> Flask:
|
||||
def initialize_extensions(app):
|
||||
# Since the application instance is now created, pass it to each Flask
|
||||
# extension instance to bind it to the Flask application instance (app)
|
||||
ext_code_based_extension.init()
|
||||
ext_database.init_app(app)
|
||||
ext_migrate.init(app, db)
|
||||
ext_redis.init_app(app)
|
||||
@ -122,6 +126,7 @@ def register_blueprints(app):
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.web import bp as web_bp
|
||||
from controllers.console import bp as console_app_bp
|
||||
from controllers.files import bp as files_bp
|
||||
|
||||
CORS(service_api_bp,
|
||||
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
|
||||
@ -151,6 +156,12 @@ def register_blueprints(app):
|
||||
|
||||
app.register_blueprint(console_app_bp)
|
||||
|
||||
CORS(files_bp,
|
||||
allow_headers=['Content-Type'],
|
||||
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
|
||||
)
|
||||
app.register_blueprint(files_bp)
|
||||
|
||||
|
||||
# create app
|
||||
app = create_app()
|
||||
|
||||
219
api/config.py
219
api/config.py
@ -26,6 +26,7 @@ DEFAULTS = {
|
||||
'SERVICE_API_URL': 'https://api.dify.ai',
|
||||
'APP_WEB_URL': 'https://udify.app',
|
||||
'APP_API_URL': 'https://udify.app',
|
||||
'FILES_URL': '',
|
||||
'STORAGE_TYPE': 'local',
|
||||
'STORAGE_LOCAL_PATH': 'storage',
|
||||
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
|
||||
@ -57,6 +58,9 @@ DEFAULTS = {
|
||||
'CLEAN_DAY_SETTING': 30,
|
||||
'UPLOAD_FILE_SIZE_LIMIT': 15,
|
||||
'UPLOAD_FILE_BATCH_LIMIT': 5,
|
||||
'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
|
||||
'OUTPUT_MODERATION_BUFFER_SIZE': 300,
|
||||
'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64'
|
||||
}
|
||||
|
||||
|
||||
@ -83,78 +87,65 @@ class Config:
|
||||
"""Application configuration class."""
|
||||
|
||||
def __init__(self):
|
||||
# app settings
|
||||
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
|
||||
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
|
||||
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
|
||||
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
|
||||
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.25"
|
||||
# ------------------------
|
||||
# General Configurations.
|
||||
# ------------------------
|
||||
self.CURRENT_VERSION = "0.3.30"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
self.TESTING = False
|
||||
self.LOG_LEVEL = get_env('LOG_LEVEL')
|
||||
|
||||
# The backend URL prefix of the console API.
|
||||
# used to concatenate the login authorization callback or notion integration callback.
|
||||
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
|
||||
|
||||
# The front-end URL prefix of the console web.
|
||||
# used to concatenate some front-end addresses and for CORS configuration use.
|
||||
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
|
||||
|
||||
# WebApp API backend Url prefix.
|
||||
# used to declare the back-end URL for the front-end API.
|
||||
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
|
||||
|
||||
# WebApp Url prefix.
|
||||
# used to display WebAPP API Base Url to the front-end.
|
||||
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
|
||||
|
||||
# Service API Url prefix.
|
||||
# used to display Service API Base Url to the front-end.
|
||||
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
|
||||
|
||||
# File preview or download Url prefix.
|
||||
# used to display File preview or download Url to the front-end or as Multi-model inputs;
|
||||
# Url is signed and has expiration time.
|
||||
self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL
|
||||
|
||||
# Fallback Url prefix.
|
||||
# Will be deprecated in the future.
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
|
||||
# Your App secret key will be used for securely signing the session cookie
|
||||
# Make sure you are changing this key for your deployment with a strong key.
|
||||
# You can generate a strong key using `openssl rand -base64 42`.
|
||||
# Alternatively you can set it with `SECRET_KEY` environment variable.
|
||||
self.SECRET_KEY = get_env('SECRET_KEY')
|
||||
|
||||
# redis settings
|
||||
self.REDIS_HOST = get_env('REDIS_HOST')
|
||||
self.REDIS_PORT = get_env('REDIS_PORT')
|
||||
self.REDIS_USERNAME = get_env('REDIS_USERNAME')
|
||||
self.REDIS_PASSWORD = get_env('REDIS_PASSWORD')
|
||||
self.REDIS_DB = get_env('REDIS_DB')
|
||||
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
|
||||
|
||||
# storage settings
|
||||
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
|
||||
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
|
||||
self.S3_ENDPOINT = get_env('S3_ENDPOINT')
|
||||
self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
|
||||
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
|
||||
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
|
||||
self.S3_REGION = get_env('S3_REGION')
|
||||
|
||||
# vector store settings, only support weaviate, qdrant
|
||||
self.VECTOR_STORE = get_env('VECTOR_STORE')
|
||||
|
||||
# weaviate settings
|
||||
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
|
||||
self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
|
||||
self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
|
||||
self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
|
||||
|
||||
# qdrant settings
|
||||
self.QDRANT_URL = get_env('QDRANT_URL')
|
||||
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
|
||||
|
||||
# cors settings
|
||||
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
|
||||
self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
|
||||
'WEB_API_CORS_ALLOW_ORIGINS', '*')
|
||||
|
||||
# mail settings
|
||||
self.MAIL_TYPE = get_env('MAIL_TYPE')
|
||||
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
|
||||
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
|
||||
|
||||
# sentry settings
|
||||
self.SENTRY_DSN = get_env('SENTRY_DSN')
|
||||
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
|
||||
self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
|
||||
|
||||
# check update url
|
||||
self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL')
|
||||
|
||||
# database settings
|
||||
# ------------------------
|
||||
# Database Configurations.
|
||||
# ------------------------
|
||||
db_credentials = {
|
||||
key: get_env(key) for key in
|
||||
['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
|
||||
@ -168,14 +159,102 @@ class Config:
|
||||
|
||||
self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
|
||||
|
||||
# celery settings
|
||||
# ------------------------
|
||||
# Redis Configurations.
|
||||
# ------------------------
|
||||
self.REDIS_HOST = get_env('REDIS_HOST')
|
||||
self.REDIS_PORT = get_env('REDIS_PORT')
|
||||
self.REDIS_USERNAME = get_env('REDIS_USERNAME')
|
||||
self.REDIS_PASSWORD = get_env('REDIS_PASSWORD')
|
||||
self.REDIS_DB = get_env('REDIS_DB')
|
||||
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
|
||||
|
||||
# ------------------------
|
||||
# Celery worker Configurations.
|
||||
# ------------------------
|
||||
self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL')
|
||||
self.CELERY_BACKEND = get_env('CELERY_BACKEND')
|
||||
self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
|
||||
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
|
||||
self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
|
||||
|
||||
# hosted provider credentials
|
||||
# ------------------------
|
||||
# File Storage Configurations.
|
||||
# ------------------------
|
||||
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
|
||||
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
|
||||
self.S3_ENDPOINT = get_env('S3_ENDPOINT')
|
||||
self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
|
||||
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
|
||||
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
|
||||
self.S3_REGION = get_env('S3_REGION')
|
||||
|
||||
# ------------------------
|
||||
# Vector Store Configurations.
|
||||
# Currently, only support: qdrant, milvus, zilliz, weaviate
|
||||
# ------------------------
|
||||
self.VECTOR_STORE = get_env('VECTOR_STORE')
|
||||
|
||||
# qdrant settings
|
||||
self.QDRANT_URL = get_env('QDRANT_URL')
|
||||
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
|
||||
|
||||
# milvus / zilliz setting
|
||||
self.MILVUS_HOST = get_env('MILVUS_HOST')
|
||||
self.MILVUS_PORT = get_env('MILVUS_PORT')
|
||||
self.MILVUS_USER = get_env('MILVUS_USER')
|
||||
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
|
||||
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
|
||||
|
||||
# weaviate settings
|
||||
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
|
||||
self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
|
||||
self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
|
||||
self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
|
||||
|
||||
# ------------------------
|
||||
# Mail Configurations.
|
||||
# ------------------------
|
||||
self.MAIL_TYPE = get_env('MAIL_TYPE')
|
||||
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
|
||||
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
|
||||
|
||||
# ------------------------
|
||||
# Sentry Configurations.
|
||||
# ------------------------
|
||||
self.SENTRY_DSN = get_env('SENTRY_DSN')
|
||||
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
|
||||
self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
|
||||
|
||||
# ------------------------
|
||||
# Business Configurations.
|
||||
# ------------------------
|
||||
|
||||
# multi model send image format, support base64, url, default is base64
|
||||
self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT')
|
||||
|
||||
# Dataset Configurations.
|
||||
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
|
||||
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
|
||||
|
||||
# File upload Configurations.
|
||||
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
|
||||
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
|
||||
self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
|
||||
|
||||
# Moderation in app Configurations.
|
||||
self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
|
||||
|
||||
# Notion integration setting
|
||||
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
|
||||
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
|
||||
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
|
||||
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
|
||||
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
|
||||
|
||||
# ------------------------
|
||||
# Platform Configurations.
|
||||
# ------------------------
|
||||
self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
|
||||
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
|
||||
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
|
||||
@ -203,23 +282,6 @@ class Config:
|
||||
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
|
||||
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
|
||||
|
||||
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
|
||||
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
|
||||
|
||||
# notion import setting
|
||||
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
|
||||
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
|
||||
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
|
||||
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
|
||||
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
|
||||
|
||||
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
|
||||
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
|
||||
|
||||
# uploading settings
|
||||
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
|
||||
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
|
||||
|
||||
|
||||
class CloudEditionConfig(Config):
|
||||
|
||||
@ -234,18 +296,5 @@ class CloudEditionConfig(Config):
|
||||
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
|
||||
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
|
||||
|
||||
|
||||
class TestConfig(Config):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.TESTING = True
|
||||
|
||||
db_credentials = {
|
||||
key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT']
|
||||
}
|
||||
|
||||
# use a different database for testing: dify_test
|
||||
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/dify_test"
|
||||
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
|
||||
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
|
||||
|
||||
@ -31,6 +31,7 @@ model_templates = {
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
@ -81,6 +82,7 @@ model_templates = {
|
||||
'model': json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 512,
|
||||
"temperature": 1,
|
||||
@ -137,10 +139,11 @@ demo_model_templates = {
|
||||
},
|
||||
opening_statement='',
|
||||
suggested_questions=None,
|
||||
pre_prompt="Please translate the following text into {{target_language}}:\n",
|
||||
pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
@ -169,6 +172,13 @@ demo_model_templates = {
|
||||
'Italian',
|
||||
]
|
||||
}
|
||||
},{
|
||||
"paragraph": {
|
||||
"label": "Query",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
])
|
||||
)
|
||||
@ -200,6 +210,7 @@ demo_model_templates = {
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.8,
|
||||
@ -255,10 +266,11 @@ demo_model_templates = {
|
||||
},
|
||||
opening_statement='',
|
||||
suggested_questions=None,
|
||||
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
|
||||
pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo-instruct",
|
||||
"mode": "completion",
|
||||
"completion_params": {
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0,
|
||||
@ -287,6 +299,13 @@ demo_model_templates = {
|
||||
"意大利语",
|
||||
]
|
||||
}
|
||||
},{
|
||||
"paragraph": {
|
||||
"label": "文本内容",
|
||||
"variable": "query",
|
||||
"required": True,
|
||||
"default": ""
|
||||
}
|
||||
}
|
||||
])
|
||||
)
|
||||
@ -318,6 +337,7 @@ demo_model_templates = {
|
||||
model=json.dumps({
|
||||
"provider": "openai",
|
||||
"name": "gpt-3.5-turbo",
|
||||
"mode": "chat",
|
||||
"completion_params": {
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.8,
|
||||
|
||||
@ -6,10 +6,10 @@ bp = Blueprint('console', __name__, url_prefix='/console/api')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
# Import other controllers
|
||||
from . import setup, version, apikey, admin
|
||||
from . import extension, setup, version, apikey, admin
|
||||
|
||||
# Import app controllers
|
||||
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
|
||||
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import login, oauth, data_source_oauth, activate
|
||||
|
||||
25
api/controllers/console/app/advanced_prompt_template.py
Normal file
25
api/controllers/console/app/advanced_prompt_template.py
Normal file
@ -0,0 +1,25 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
class AdvancedPromptTemplateList(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('app_mode', type=str, required=True, location='args')
|
||||
parser.add_argument('model_mode', type=str, required=True, location='args')
|
||||
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
|
||||
parser.add_argument('model_name', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
|
||||
@ -40,12 +40,14 @@ class CompletionMessageApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] != 'blocking'
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
account = flask_login.current_user
|
||||
|
||||
@ -113,6 +115,7 @@ class ChatMessageApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('model_config', type=dict, required=True, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
@ -120,6 +123,7 @@ class ChatMessageApi(Resource):
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] != 'blocking'
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
account = flask_login.current_user
|
||||
|
||||
|
||||
@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_id, conversation_id, 'completion')
|
||||
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -230,7 +230,7 @@ class ChatConversationDetailApi(Resource):
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_id, conversation_id, 'chat')
|
||||
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -253,8 +253,6 @@ class ChatConversationDetailApi(Resource):
|
||||
return {'result': 'success'}, 204
|
||||
|
||||
|
||||
|
||||
|
||||
api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
|
||||
api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
|
||||
api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
|
||||
|
||||
@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
|
||||
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
|
||||
|
||||
|
||||
class IntroductionGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('prompt_template', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
|
||||
try:
|
||||
answer = LLMGenerator.generate_introduction(
|
||||
account.current_tenant_id,
|
||||
args['prompt_template']
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, LLMAuthorizationError) as e:
|
||||
raise CompletionRequestError(str(e))
|
||||
|
||||
return {'introduction': answer}
|
||||
|
||||
|
||||
class RuleGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -72,5 +43,4 @@ class RuleGenerateApi(Resource):
|
||||
return rules
|
||||
|
||||
|
||||
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
|
||||
api.add_resource(RuleGenerateApi, '/rule-generate')
|
||||
|
||||
@ -295,8 +295,8 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id,
|
||||
user=current_user,
|
||||
check_enabled=False
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
@ -329,7 +329,7 @@ class MessageApi(Resource):
|
||||
message_id = str(message_id)
|
||||
|
||||
# get app info
|
||||
app_model = _get_app(app_id, 'chat')
|
||||
app_model = _get_app(app_id)
|
||||
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from libs.login import login_required
|
||||
@ -20,8 +19,6 @@ from models.source import DataSourceBinding
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
|
||||
class DataSourceApi(Resource):
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
|
||||
import services
|
||||
from libs.login import login_required
|
||||
@ -15,9 +15,6 @@ from fields.file_fields import upload_config_fields, file_fields
|
||||
|
||||
from services.file_service import FileService
|
||||
|
||||
cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
@ -30,9 +27,11 @@ class FileApi(Resource):
|
||||
def get(self):
|
||||
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
|
||||
batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
|
||||
image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT")
|
||||
return {
|
||||
'file_size_limit': file_size_limit,
|
||||
'batch_count_limit': batch_count_limit
|
||||
'batch_count_limit': batch_count_limit,
|
||||
'image_file_size_limit': image_file_size_limit
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@ -51,7 +50,7 @@ class FileApi(Resource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
try:
|
||||
upload_file = FileService.upload_file(file)
|
||||
upload_file = FileService.upload_file(file, current_user)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Generator, Union
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
@ -17,6 +18,7 @@ from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.conversation_message_task import PubHandler
|
||||
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import uuid_value
|
||||
from services.completion_service import CompletionService
|
||||
|
||||
@ -32,11 +34,16 @@ class CompletionApi(InstalledAppResource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
installed_app.last_used_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
@ -91,12 +98,17 @@ class ChatApi(InstalledAppResource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
installed_app.last_used_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
|
||||
@ -38,7 +38,8 @@ class ConversationListApi(InstalledAppResource):
|
||||
user=current_user,
|
||||
last_id=args['last_id'],
|
||||
limit=args['limit'],
|
||||
pinned=pinned
|
||||
pinned=pinned,
|
||||
exclude_debug_conversation=True
|
||||
)
|
||||
except LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
@ -71,11 +72,18 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
|
||||
return ConversationService.rename(
|
||||
app_model,
|
||||
conversation_id,
|
||||
current_user,
|
||||
args['name'],
|
||||
args['auto_generate']
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@ -39,8 +39,9 @@ class InstalledAppsListApi(Resource):
|
||||
}
|
||||
for installed_app in installed_apps
|
||||
]
|
||||
installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at']
|
||||
if app['last_used_at'] is not None else datetime.min))
|
||||
installed_apps.sort(key=lambda app: (-app['is_pinned'],
|
||||
app['last_used_at'] is None,
|
||||
-app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
|
||||
|
||||
return {'installed_apps': installed_apps}
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import marshal_with, fields
|
||||
from flask import current_app
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
@ -19,6 +20,10 @@ class AppParameterApi(InstalledAppResource):
|
||||
'options': fields.List(fields.String)
|
||||
}
|
||||
|
||||
system_parameters_fields = {
|
||||
'image_file_size_limit': fields.String
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
@ -27,6 +32,9 @@ class AppParameterApi(InstalledAppResource):
|
||||
'retriever_resource': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
'file_upload': fields.Raw,
|
||||
'system_parameters': fields.Nested(system_parameters_fields)
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@ -42,7 +50,12 @@ class AppParameterApi(InstalledAppResource):
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
'user_input_form': app_model_config.user_input_form_list,
|
||||
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
|
||||
'file_upload': app_model_config.file_upload_dict,
|
||||
'system_parameters': {
|
||||
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from controllers.console.explore.wraps import InstalledAppResource
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
from fields.conversation_fields import message_file_fields
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
@ -19,6 +20,7 @@ message_fields = {
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
114
api/controllers/console/extension.py
Normal file
114
api/controllers/console/extension.py
Normal file
@ -0,0 +1,114 @@
|
||||
from flask_restful import Resource, reqparse, marshal_with
|
||||
from flask_login import current_user
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.login import login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from fields.api_based_extension_fields import api_based_extension_fields
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
|
||||
|
||||
class CodeBasedExtensionAPI(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('module', type=str, required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
return {
|
||||
'module': args['module'],
|
||||
'data': CodeBasedExtensionService.get_code_based_extension(args['module'])
|
||||
}
|
||||
|
||||
|
||||
class APIBasedExtensionAPI(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('api_endpoint', type=str, required=True, location='json')
|
||||
parser.add_argument('api_key', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
name=args['name'],
|
||||
api_endpoint=args['api_endpoint'],
|
||||
api_key=args['api_key']
|
||||
)
|
||||
|
||||
return APIBasedExtensionService.save(extension_data)
|
||||
|
||||
|
||||
class APIBasedExtensionDetailAPI(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def get(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(api_based_extension_fields)
|
||||
def post(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('api_endpoint', type=str, required=True, location='json')
|
||||
parser.add_argument('api_key', type=str, required=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
extension_data_from_db.name = args['name']
|
||||
extension_data_from_db.api_endpoint = args['api_endpoint']
|
||||
|
||||
if args['api_key'] != '[__HIDDEN__]':
|
||||
extension_data_from_db.api_key = args['api_key']
|
||||
|
||||
return APIBasedExtensionService.save(extension_data_from_db)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
|
||||
APIBasedExtensionService.delete(extension_data_from_db)
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
api.add_resource(CodeBasedExtensionAPI, '/code-based-extension')
|
||||
|
||||
api.add_resource(APIBasedExtensionAPI, '/api-based-extension')
|
||||
api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>')
|
||||
@ -25,6 +25,7 @@ class UniversalChatApi(UniversalChatResource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('provider', type=str, required=True, location='json')
|
||||
parser.add_argument('model', type=str, required=True, location='json')
|
||||
@ -60,6 +61,8 @@ class UniversalChatApi(UniversalChatResource):
|
||||
del args['model']
|
||||
del args['tools']
|
||||
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
app_model=app_model,
|
||||
|
||||
@ -65,11 +65,18 @@ class UniversalChatConversationRenameApi(UniversalChatResource):
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
|
||||
return ConversationService.rename(
|
||||
app_model,
|
||||
conversation_id,
|
||||
current_user,
|
||||
args['name'],
|
||||
args['auto_generate']
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
10
api/controllers/files/__init__.py
Normal file
10
api/controllers/files/__init__.py
Normal file
@ -0,0 +1,10 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint('files', __name__)
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import image_preview
|
||||
40
api/controllers/files/image_preview.py
Normal file
40
api/controllers/files/image_preview.py
Normal file
@ -0,0 +1,40 @@
|
||||
from flask import request, Response
|
||||
from flask_restful import Resource
|
||||
|
||||
import services
|
||||
from controllers.files import api
|
||||
from libs.exception import BaseHTTPException
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
class ImagePreviewApi(Resource):
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
timestamp = request.args.get('timestamp')
|
||||
nonce = request.args.get('nonce')
|
||||
sign = request.args.get('sign')
|
||||
|
||||
if not timestamp or not nonce or not sign:
|
||||
return {'content': 'Invalid request.'}, 400
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_image_preview(
|
||||
file_id,
|
||||
timestamp,
|
||||
nonce,
|
||||
sign
|
||||
)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return Response(generator, mimetype=mimetype)
|
||||
|
||||
|
||||
api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_file_type'
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from .app import completion, app, conversation, message, audio
|
||||
from .app import completion, app, conversation, message, audio, file
|
||||
|
||||
from .dataset import document, segment, dataset
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import fields, marshal_with
|
||||
from flask import current_app
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
@ -20,6 +21,10 @@ class AppParameterApi(AppApiResource):
|
||||
'options': fields.List(fields.String)
|
||||
}
|
||||
|
||||
system_parameters_fields = {
|
||||
'image_file_size_limit': fields.String
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
@ -28,6 +33,9 @@ class AppParameterApi(AppApiResource):
|
||||
'retriever_resource': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
'file_upload': fields.Raw,
|
||||
'system_parameters': fields.Nested(system_parameters_fields)
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@ -42,7 +50,12 @@ class AppParameterApi(AppApiResource):
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
'user_input_form': app_model_config.user_input_form_list,
|
||||
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
|
||||
'file_upload': app_model_config.file_upload_dict,
|
||||
'system_parameters': {
|
||||
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@ class CompletionApi(AppApiResource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
@ -39,13 +40,15 @@ class CompletionApi(AppApiResource):
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
args=args,
|
||||
from_source='api',
|
||||
streaming=streaming
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
return compact_response(response)
|
||||
@ -90,10 +93,12 @@ class ChatApi(AppApiResource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
|
||||
parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -183,4 +188,3 @@ api.add_resource(CompletionApi, '/completion-messages')
|
||||
api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
|
||||
api.add_resource(ChatApi, '/chat-messages')
|
||||
api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')
|
||||
|
||||
|
||||
@ -54,6 +54,7 @@ class ConversationDetailApi(AppApiResource):
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ConversationRenameApi(AppApiResource):
|
||||
|
||||
@marshal_with(simple_conversation_fields)
|
||||
@ -64,15 +65,22 @@ class ConversationRenameApi(AppApiResource):
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('user', type=str, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
if end_user is None and args['user'] is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
|
||||
return ConversationService.rename(
|
||||
app_model,
|
||||
conversation_id,
|
||||
end_user,
|
||||
args['name'],
|
||||
args['auto_generate']
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@ -75,3 +75,26 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = 'no_file_uploaded'
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = 'too_many_files'
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class FileTooLargeError(BaseHTTPException):
|
||||
error_code = 'file_too_large'
|
||||
description = "File size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_file_type'
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
|
||||
42
api/controllers/service_api/app/file.py
Normal file
42
api/controllers/service_api/app/file.py
Normal file
@ -0,0 +1,42 @@
|
||||
from flask import request
|
||||
from flask_restful import marshal_with
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from controllers.service_api.app import create_or_update_end_user_for_user_id
|
||||
from controllers.service_api.app.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
|
||||
UnsupportedFileTypeError
|
||||
import services
|
||||
from services.file_service import FileService
|
||||
from fields.file_fields import file_fields
|
||||
|
||||
|
||||
class FileApi(AppApiResource):
|
||||
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model, end_user):
|
||||
|
||||
file = request.files['file']
|
||||
user_args = request.form.get('user')
|
||||
|
||||
if end_user is None and user_args is not None:
|
||||
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
|
||||
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(file, end_user)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
api.add_resource(FileApi, '/files/upload')
|
||||
@ -10,7 +10,9 @@ from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import AppApiResource
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from services.message_service import MessageService
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, EndUser
|
||||
from fields.conversation_fields import message_file_fields
|
||||
|
||||
class MessageListApi(AppApiResource):
|
||||
feedback_fields = {
|
||||
@ -41,6 +43,7 @@ class MessageListApi(AppApiResource):
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField
|
||||
@ -96,5 +99,38 @@ class MessageFeedbackApi(AppApiResource):
|
||||
return {'result': 'success'}
|
||||
|
||||
|
||||
class MessageSuggestedApi(AppApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
message_id = str(message_id)
|
||||
if app_model.mode != 'chat':
|
||||
raise NotChatAppError()
|
||||
try:
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
).first()
|
||||
|
||||
if end_user is None and message.from_end_user_id is not None:
|
||||
user = db.session.query(EndUser) \
|
||||
.filter(
|
||||
EndUser.tenant_id == app_model.tenant_id,
|
||||
EndUser.id == message.from_end_user_id,
|
||||
EndUser.type == 'service_api'
|
||||
).first()
|
||||
else:
|
||||
user = end_user
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
message_id=message_id,
|
||||
check_enabled=False
|
||||
)
|
||||
except services.errors.message.MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
return {'result': 'success', 'data': questions}
|
||||
|
||||
|
||||
api.add_resource(MessageListApi, '/messages')
|
||||
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
|
||||
api.add_resource(MessageSuggestedApi, '/messages/<uuid:message_id>/suggested')
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
|
||||
from flask import request
|
||||
from flask_restful import reqparse, marshal
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import desc
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@ -173,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
upload_file = FileService.upload_file(file)
|
||||
upload_file = FileService.upload_file(file, current_user)
|
||||
data_source = {
|
||||
'type': 'upload_file',
|
||||
'info_list': {
|
||||
@ -235,7 +236,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
upload_file = FileService.upload_file(file)
|
||||
upload_file = FileService.upload_file(file, current_user)
|
||||
data_source = {
|
||||
'type': 'upload_file',
|
||||
'info_list': {
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
@ -9,8 +8,8 @@ from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestE
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import segment_fields
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DocumentService, SegmentService
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
|
||||
|
||||
class SegmentApi(DatasetApiResource):
|
||||
@ -24,6 +23,8 @@ class SegmentApi(DatasetApiResource):
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
@ -55,5 +56,146 @@ class SegmentApi(DatasetApiResource):
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
"""Create single segment."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('status', type=str,
|
||||
action='append', default=[], location='args')
|
||||
parser.add_argument('keyword', type=str, default=None, location='args')
|
||||
args = parser.parse_args()
|
||||
|
||||
status_list = args['status']
|
||||
keyword = args['keyword']
|
||||
|
||||
query = DocumentSegment.query.filter(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
)
|
||||
|
||||
if status_list:
|
||||
query = query.filter(DocumentSegment.status.in_(status_list))
|
||||
|
||||
if keyword:
|
||||
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
|
||||
|
||||
total = query.count()
|
||||
segments = query.order_by(DocumentSegment.position).all()
|
||||
return {
|
||||
'data': marshal(segments, segment_fields),
|
||||
'doc_form': document.doc_form,
|
||||
'total': total
|
||||
}, 200
|
||||
|
||||
|
||||
class DatasetSegmentApi(DatasetApiResource):
|
||||
def delete(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check segment
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound('Segment not found.')
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
def post(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id
|
||||
).first()
|
||||
if not segment:
|
||||
raise NotFound('Segment not found.')
|
||||
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
SegmentService.segment_create_args_validate(args['segments'], document)
|
||||
segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
|
||||
return {
|
||||
'data': marshal(segment, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
||||
api.add_resource(DatasetSegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
|
||||
|
||||
@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import completion, app, conversation, message, site, saved_message, audio, passport
|
||||
from . import completion, app, conversation, message, site, saved_message, audio, passport, file
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_restful import marshal_with, fields
|
||||
from flask import current_app
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
@ -19,6 +20,10 @@ class AppParameterApi(WebApiResource):
|
||||
'options': fields.List(fields.String)
|
||||
}
|
||||
|
||||
system_parameters_fields = {
|
||||
'image_file_size_limit': fields.String
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
@ -27,6 +32,9 @@ class AppParameterApi(WebApiResource):
|
||||
'retriever_resource': fields.Raw,
|
||||
'more_like_this': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'sensitive_word_avoidance': fields.Raw,
|
||||
'file_upload': fields.Raw,
|
||||
'system_parameters': fields.Nested(system_parameters_fields)
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@ -41,7 +49,12 @@ class AppParameterApi(WebApiResource):
|
||||
'speech_to_text': app_model_config.speech_to_text_dict,
|
||||
'retriever_resource': app_model_config.retriever_resource_dict,
|
||||
'more_like_this': app_model_config.more_like_this_dict,
|
||||
'user_input_form': app_model_config.user_input_form_list
|
||||
'user_input_form': app_model_config.user_input_form_list,
|
||||
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
|
||||
'file_upload': app_model_config.file_upload_dict,
|
||||
'system_parameters': {
|
||||
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -30,12 +30,14 @@ class CompletionApi(WebApiResource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, location='json', default='')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
@ -88,6 +90,7 @@ class ChatApi(WebApiResource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('inputs', type=dict, required=True, location='json')
|
||||
parser.add_argument('query', type=str, required=True, location='json')
|
||||
parser.add_argument('files', type=list, required=False, location='json')
|
||||
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
|
||||
parser.add_argument('conversation_id', type=uuid_value, location='json')
|
||||
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
|
||||
@ -95,6 +98,7 @@ class ChatApi(WebApiResource):
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
args['auto_generate_name'] = False
|
||||
|
||||
try:
|
||||
response = CompletionService.completion(
|
||||
@ -139,7 +143,7 @@ class ChatStopApi(WebApiResource):
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
def compact_response(response: Union[dict | Generator]) -> Response:
|
||||
def compact_response(response: Union[dict, Generator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
||||
else:
|
||||
|
||||
@ -67,11 +67,18 @@ class ConversationRenameApi(WebApiResource):
|
||||
conversation_id = str(c_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', type=str, required=True, location='json')
|
||||
parser.add_argument('name', type=str, required=False, location='json')
|
||||
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
|
||||
return ConversationService.rename(
|
||||
app_model,
|
||||
conversation_id,
|
||||
end_user,
|
||||
args['name'],
|
||||
args['auto_generate']
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@ -85,4 +85,28 @@ class UnsupportedAudioTypeError(BaseHTTPException):
|
||||
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
|
||||
error_code = 'provider_not_support_speech_to_text'
|
||||
description = "Provider not support speech to text."
|
||||
code = 400
|
||||
code = 400
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = 'no_file_uploaded'
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = 'too_many_files'
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class FileTooLargeError(BaseHTTPException):
|
||||
error_code = 'file_too_large'
|
||||
description = "File size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_file_type'
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
|
||||
36
api/controllers/web/file.py
Normal file
36
api/controllers/web/file.py
Normal file
@ -0,0 +1,36 @@
|
||||
from flask import request
|
||||
from flask_restful import marshal_with
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from controllers.web.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
|
||||
UnsupportedFileTypeError
|
||||
import services
|
||||
from services.file_service import FileService
|
||||
from fields.file_fields import file_fields
|
||||
|
||||
|
||||
class FileApi(WebApiResource):
|
||||
|
||||
@marshal_with(file_fields)
|
||||
def post(self, app_model, end_user):
|
||||
# get file from request
|
||||
file = request.files['file']
|
||||
|
||||
# check file
|
||||
if 'file' not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
try:
|
||||
upload_file = FileService.upload_file(file, end_user)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
api.add_resource(FileApi, '/files/upload')
|
||||
@ -22,6 +22,7 @@ from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
from fields.conversation_fields import message_file_fields
|
||||
|
||||
|
||||
class MessageListApi(WebApiResource):
|
||||
@ -54,6 +55,7 @@ class MessageListApi(WebApiResource):
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField
|
||||
@ -115,7 +117,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
streaming = args['response_mode'] == 'streaming'
|
||||
|
||||
try:
|
||||
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
|
||||
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
|
||||
return compact_response(response)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
@ -8,6 +8,8 @@ from controllers.web.wraps import WebApiResource
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
from fields.conversation_fields import message_file_fields
|
||||
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
@ -18,6 +20,7 @@ message_fields = {
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
@ -0,0 +1 @@
|
||||
import core.moderation.base
|
||||
@ -2,14 +2,18 @@ import json
|
||||
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
|
||||
@ -24,6 +28,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
return values
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
@ -65,17 +73,57 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
try:
|
||||
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
|
||||
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
|
||||
if isinstance(agent_decision, AgentAction):
|
||||
tool_inputs = agent_decision.tool_input
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
else:
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
def real_plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
}
|
||||
)
|
||||
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
return agent_decision
|
||||
|
||||
async def aplan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
@ -87,7 +135,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
@ -96,11 +144,15 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
prompt = cls.create_prompt(
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -5,21 +5,40 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
|
||||
_format_intermediate_steps
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
|
||||
get_buffer_string
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import root_validator
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_model_instance: BaseLLM = None
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator
|
||||
def validate_llm(cls, values: dict) -> dict:
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
@ -28,12 +47,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseSingleActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
prompt = cls.create_prompt(
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=system_message,
|
||||
)
|
||||
return cls(
|
||||
model_instance=model_instance,
|
||||
llm=FakeLLM(response=''),
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=cls.get_system_message(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -44,23 +67,26 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 40
|
||||
original_max_tokens = self.model_instance.model_kwargs.max_tokens
|
||||
self.model_instance.model_kwargs.max_tokens = 40
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
try:
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
callbacks=None
|
||||
)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
function_call = result.function_call
|
||||
|
||||
self.llm.max_tokens = original_max_tokens
|
||||
self.model_instance.model_kwargs.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
|
||||
@ -93,10 +119,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
functions=self.functions,
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
|
||||
ai_message = AIMessage(
|
||||
content=result.content,
|
||||
additional_kwargs={
|
||||
'function_call': result.function_call
|
||||
}
|
||||
)
|
||||
agent_decision = _parse_ai_message(ai_message)
|
||||
|
||||
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
|
||||
tool_inputs = agent_decision.tool_input
|
||||
@ -122,3 +157,142 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
|
||||
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
|
||||
except ValueError:
|
||||
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
||||
system_message = None
|
||||
human_message = None
|
||||
should_summary_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
system_message = message
|
||||
elif isinstance(message, HumanMessage):
|
||||
human_message = message
|
||||
else:
|
||||
should_summary_messages.append(message)
|
||||
|
||||
if len(should_summary_messages) > 2:
|
||||
ai_message = should_summary_messages[-2]
|
||||
function_message = should_summary_messages[-1]
|
||||
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
|
||||
self.moving_summary_index = len(should_summary_messages)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
new_messages = [system_message, human_message]
|
||||
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, human_message)
|
||||
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
new_messages.append(AIMessage(content=self.moving_summary_buffer))
|
||||
new_messages.append(ai_message)
|
||||
new_messages.append(function_message)
|
||||
|
||||
return new_messages
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix="Human",
|
||||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
if model_instance.model_provider.provider_name == 'azure_openai':
|
||||
model = model_instance.base_model_name
|
||||
model = model.replace("gpt-35", "gpt-3.5")
|
||||
else:
|
||||
model = model_instance.base_model_name
|
||||
|
||||
tiktoken_ = _import_tiktoken()
|
||||
try:
|
||||
encoding = tiktoken_.encoding_for_model(model)
|
||||
except KeyError:
|
||||
model = "cl100k_base"
|
||||
encoding = tiktoken_.get_encoding(model)
|
||||
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
for m in messages:
|
||||
message = _convert_message_to_dict(m)
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if key == "function_call":
|
||||
for f_key, f_value in value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if kwargs.get('functions'):
|
||||
for function in kwargs.get('functions'):
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode(function.get("name")))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode(function.get("description")))
|
||||
parameters = function.get("parameters")
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
|
||||
@ -1,140 +0,0 @@
|
||||
from typing import cast, List
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.chat_models.openai import _convert_message_to_dict
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
|
||||
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel = None
|
||||
model_instance: BaseLLM
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
|
||||
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
|
||||
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
|
||||
if rest_tokens >= 0:
|
||||
return messages
|
||||
|
||||
system_message = None
|
||||
human_message = None
|
||||
should_summary_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
system_message = message
|
||||
elif isinstance(message, HumanMessage):
|
||||
human_message = message
|
||||
else:
|
||||
should_summary_messages.append(message)
|
||||
|
||||
if len(should_summary_messages) > 2:
|
||||
ai_message = should_summary_messages[-2]
|
||||
function_message = should_summary_messages[-1]
|
||||
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
|
||||
self.moving_summary_index = len(should_summary_messages)
|
||||
else:
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
new_messages = [system_message, human_message]
|
||||
|
||||
if self.moving_summary_index == 0:
|
||||
should_summary_messages.insert(0, human_message)
|
||||
|
||||
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
|
||||
new_messages.append(AIMessage(content=self.moving_summary_buffer))
|
||||
new_messages.append(ai_message)
|
||||
new_messages.append(function_message)
|
||||
|
||||
return new_messages
|
||||
|
||||
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
Official documentation: https://github.com/openai/openai-cookbook/blob/
|
||||
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
|
||||
llm = cast(ChatOpenAI, model_instance.client)
|
||||
model, encoding = llm._get_encoding_model()
|
||||
if model.startswith("gpt-3.5-turbo"):
|
||||
# every message follows <im_start>{role/name}\n{content}<im_end>\n
|
||||
tokens_per_message = 4
|
||||
# if there's a name, the role is omitted
|
||||
tokens_per_name = -1
|
||||
elif model.startswith("gpt-4"):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"get_num_tokens_from_messages() is not presently implemented "
|
||||
f"for model {model}."
|
||||
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
for m in messages:
|
||||
message = _convert_message_to_dict(m)
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if key == "function_call":
|
||||
for f_key, f_value in value.items():
|
||||
num_tokens += len(encoding.encode(f_key))
|
||||
num_tokens += len(encoding.encode(f_value))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
# every reply is primed with <im_start>assistant
|
||||
num_tokens += 3
|
||||
|
||||
if kwargs.get('functions'):
|
||||
for function in kwargs.get('functions'):
|
||||
num_tokens += len(encoding.encode('name'))
|
||||
num_tokens += len(encoding.encode(function.get("name")))
|
||||
num_tokens += len(encoding.encode('description'))
|
||||
num_tokens += len(encoding.encode(function.get("description")))
|
||||
parameters = function.get("parameters")
|
||||
num_tokens += len(encoding.encode('parameters'))
|
||||
if 'title' in parameters:
|
||||
num_tokens += len(encoding.encode('title'))
|
||||
num_tokens += len(encoding.encode(parameters.get("title")))
|
||||
num_tokens += len(encoding.encode('type'))
|
||||
num_tokens += len(encoding.encode(parameters.get("type")))
|
||||
if 'properties' in parameters:
|
||||
num_tokens += len(encoding.encode('properties'))
|
||||
for key, value in parameters.get('properties').items():
|
||||
num_tokens += len(encoding.encode(key))
|
||||
for field_key, field_value in value.items():
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
if field_key == 'enum':
|
||||
for enum_field in field_value:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(enum_field))
|
||||
else:
|
||||
num_tokens += len(encoding.encode(field_key))
|
||||
num_tokens += len(encoding.encode(str(field_value)))
|
||||
if 'required' in parameters:
|
||||
num_tokens += len(encoding.encode('required'))
|
||||
for required_field in parameters['required']:
|
||||
num_tokens += 3
|
||||
num_tokens += len(encoding.encode(required_field))
|
||||
|
||||
return num_tokens
|
||||
@ -1,107 +0,0 @@
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
|
||||
from langchain.agents import BaseMultiActionAgent
|
||||
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
|
||||
_parse_ai_message
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, SystemMessage
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
|
||||
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
|
||||
|
||||
|
||||
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||
system_message: Optional[SystemMessage] = SystemMessage(
|
||||
content="You are a helpful AI assistant."
|
||||
),
|
||||
**kwargs: Any,
|
||||
) -> BaseMultiActionAgent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
extra_prompt_messages=extra_prompt_messages,
|
||||
system_message=cls.get_system_message(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def should_use_agent(self, query: str):
|
||||
"""
|
||||
return should use agent
|
||||
|
||||
:param query:
|
||||
:return:
|
||||
"""
|
||||
original_max_tokens = self.llm.max_tokens
|
||||
self.llm.max_tokens = 15
|
||||
|
||||
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
|
||||
messages = prompt.to_messages()
|
||||
|
||||
try:
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=None
|
||||
)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
function_call = predicted_message.additional_kwargs.get("function_call", {})
|
||||
|
||||
self.llm.max_tokens = original_max_tokens
|
||||
|
||||
return True if function_call else False
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Given input, decided what to do.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages = prompt.to_messages()
|
||||
|
||||
# summarize messages if rest_tokens < 0
|
||||
try:
|
||||
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
|
||||
except ExceededLLMTokensLimitError as e:
|
||||
return AgentFinish(return_values={"output": str(e)}, log=str(e))
|
||||
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
|
||||
@classmethod
|
||||
def get_system_message(cls):
|
||||
# get current time
|
||||
return SystemMessage(content="You are a helpful AI assistant.\n"
|
||||
"The current date or current time you know is wrong.\n"
|
||||
"Respond directly if appropriate.")
|
||||
@ -1,10 +1,9 @@
|
||||
import re
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
@ -12,6 +11,8 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.model_params import ModelMode
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
|
||||
@ -49,7 +50,6 @@ Action:
|
||||
|
||||
|
||||
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
model_instance: BaseLLM
|
||||
dataset_tools: Sequence[BaseTool]
|
||||
|
||||
class Config:
|
||||
@ -93,12 +93,16 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||
|
||||
if intermediate_steps:
|
||||
_, observation = intermediate_steps[-1]
|
||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||
|
||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
||||
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
try:
|
||||
@ -108,6 +112,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
||||
tool_inputs['query'] = kwargs['input']
|
||||
agent_decision.tool_input = tool_inputs
|
||||
elif isinstance(tool_inputs, str):
|
||||
agent_decision.tool_input = kwargs['input']
|
||||
else:
|
||||
agent_decision.return_values['output'] = ''
|
||||
return agent_decision
|
||||
except OutputParserException:
|
||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
||||
@ -142,10 +150,65 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
]
|
||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
@classmethod
|
||||
def create_completion_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
Args:
|
||||
tools: List of tools the agent will have access to, used to format the
|
||||
prompt.
|
||||
prefix: String to put before the list of tools.
|
||||
input_variables: List of input variables the final prompt will expect.
|
||||
|
||||
Returns:
|
||||
A PromptTemplate with the template assembled from the pieces here.
|
||||
"""
|
||||
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
"""
|
||||
|
||||
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
agent_scratchpad += action.log
|
||||
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
|
||||
|
||||
if not isinstance(agent_scratchpad, str):
|
||||
raise ValueError("agent_scratchpad should be of type string.")
|
||||
if agent_scratchpad:
|
||||
llm_chain = cast(LLMChain, self.llm_chain)
|
||||
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
|
||||
return (
|
||||
f"This was your previous work "
|
||||
f"(but I haven't seen any of it! I only see what "
|
||||
f"you return as final answer):\n{agent_scratchpad}"
|
||||
)
|
||||
else:
|
||||
return agent_scratchpad
|
||||
else:
|
||||
return agent_scratchpad
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
@ -157,17 +220,36 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
if model_instance.model_mode == ModelMode.CHAT:
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
else:
|
||||
prompt = cls.create_completion_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
output_parser=_output_parser,
|
||||
dataset_tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1,19 +1,21 @@
|
||||
import re
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional
|
||||
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
|
||||
|
||||
from langchain import BasePromptTemplate
|
||||
from langchain import BasePromptTemplate, PromptTemplate
|
||||
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
|
||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
from langchain.memory.prompt import SUMMARY_PROMPT
|
||||
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
|
||||
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
|
||||
get_buffer_string
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
||||
|
||||
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
|
||||
from core.chain.llm_chain import LLMChain
|
||||
from core.model_providers.models.entity.model_params import ModelMode
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
|
||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
@ -52,8 +54,7 @@ Action:
|
||||
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
moving_summary_buffer: str = ""
|
||||
moving_summary_index: int = 0
|
||||
summary_llm: BaseLanguageModel = None
|
||||
model_instance: BaseLLM
|
||||
summary_model_instance: BaseLLM = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -95,14 +96,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
if prompts:
|
||||
messages = prompts[0].to_messages()
|
||||
|
||||
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
|
||||
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
|
||||
if rest_tokens < 0:
|
||||
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
|
||||
|
||||
try:
|
||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
||||
except Exception as e:
|
||||
new_exception = self.model_instance.handle_exceptions(e)
|
||||
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
|
||||
raise new_exception
|
||||
|
||||
try:
|
||||
@ -118,7 +119,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
"I don't know how to respond to that."}, "")
|
||||
|
||||
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
|
||||
if len(intermediate_steps) >= 2 and self.summary_llm:
|
||||
if len(intermediate_steps) >= 2 and self.summary_model_instance:
|
||||
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
|
||||
should_summary_messages = [AIMessage(content=observation)
|
||||
for _, observation in should_summary_intermediate_steps]
|
||||
@ -130,11 +131,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
error_msg = "Exceeded LLM tokens limit, stopped."
|
||||
raise ExceededLLMTokensLimitError(error_msg)
|
||||
|
||||
summary_handler = SummarizerMixin(llm=self.summary_llm)
|
||||
if self.moving_summary_buffer and 'chat_history' in kwargs:
|
||||
kwargs["chat_history"].pop()
|
||||
|
||||
self.moving_summary_buffer = summary_handler.predict_new_summary(
|
||||
self.moving_summary_buffer = self.predict_new_summary(
|
||||
messages=should_summary_messages,
|
||||
existing_summary=self.moving_summary_buffer
|
||||
)
|
||||
@ -144,6 +144,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
|
||||
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
|
||||
|
||||
def predict_new_summary(
|
||||
self, messages: List[BaseMessage], existing_summary: str
|
||||
) -> str:
|
||||
new_lines = get_buffer_string(
|
||||
messages,
|
||||
human_prefix="Human",
|
||||
ai_prefix="AI",
|
||||
)
|
||||
|
||||
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
|
||||
return chain.predict(summary=existing_summary, new_lines=new_lines)
|
||||
|
||||
@classmethod
|
||||
def create_prompt(
|
||||
cls,
|
||||
@ -173,10 +185,65 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
]
|
||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
||||
|
||||
@classmethod
|
||||
def create_completion_prompt(
|
||||
cls,
|
||||
tools: Sequence[BaseTool],
|
||||
prefix: str = PREFIX,
|
||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
) -> PromptTemplate:
|
||||
"""Create prompt in the style of the zero shot agent.
|
||||
|
||||
Args:
|
||||
tools: List of tools the agent will have access to, used to format the
|
||||
prompt.
|
||||
prefix: String to put before the list of tools.
|
||||
input_variables: List of input variables the final prompt will expect.
|
||||
|
||||
Returns:
|
||||
A PromptTemplate with the template assembled from the pieces here.
|
||||
"""
|
||||
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
||||
Question: {input}
|
||||
Thought: {agent_scratchpad}
|
||||
"""
|
||||
|
||||
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
||||
tool_names = ", ".join([tool.name for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
return PromptTemplate(template=template, input_variables=input_variables)
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> str:
|
||||
agent_scratchpad = ""
|
||||
for action, observation in intermediate_steps:
|
||||
agent_scratchpad += action.log
|
||||
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
|
||||
|
||||
if not isinstance(agent_scratchpad, str):
|
||||
raise ValueError("agent_scratchpad should be of type string.")
|
||||
if agent_scratchpad:
|
||||
llm_chain = cast(LLMChain, self.llm_chain)
|
||||
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
|
||||
return (
|
||||
f"This was your previous work "
|
||||
f"(but I haven't seen any of it! I only see what "
|
||||
f"you return as final answer):\n{agent_scratchpad}"
|
||||
)
|
||||
else:
|
||||
return agent_scratchpad
|
||||
else:
|
||||
return agent_scratchpad
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
model_instance: BaseLLM,
|
||||
tools: Sequence[BaseTool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
output_parser: Optional[AgentOutputParser] = None,
|
||||
@ -188,16 +255,35 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
|
||||
memory_prompts: Optional[List[BasePromptTemplate]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
return super().from_llm_and_tools(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
if model_instance.model_mode == ModelMode.CHAT:
|
||||
prompt = cls.create_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
else:
|
||||
prompt = cls.create_completion_prompt(
|
||||
tools,
|
||||
prefix=prefix,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
)
|
||||
llm_chain = LLMChain(
|
||||
model_instance=model_instance,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
output_parser=output_parser,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
human_message_template=human_message_template,
|
||||
format_instructions=format_instructions,
|
||||
input_variables=input_variables,
|
||||
memory_prompts=memory_prompts,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
_output_parser = output_parser
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
output_parser=_output_parser,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -10,7 +10,6 @@ from pydantic import BaseModel, Extra
|
||||
|
||||
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
||||
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
|
||||
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
|
||||
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
|
||||
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
|
||||
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
|
||||
@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum):
|
||||
REACT_ROUTER = 'react_router'
|
||||
REACT = 'react'
|
||||
FUNCTION_CALL = 'function_call'
|
||||
MULTI_FUNCTION_CALL = 'multi_function_call'
|
||||
|
||||
|
||||
class AgentConfiguration(BaseModel):
|
||||
@ -64,30 +62,18 @@ class AgentExecutor:
|
||||
if self.configuration.strategy == PlanningStrategy.REACT:
|
||||
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
summary_model_instance=self.configuration.summary_model_instance
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
|
||||
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
|
||||
summary_llm=self.configuration.summary_model_instance.client
|
||||
summary_model_instance=self.configuration.summary_model_instance
|
||||
if self.configuration.summary_model_instance else None,
|
||||
verbose=True
|
||||
)
|
||||
@ -95,7 +81,6 @@ class AgentExecutor:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
|
||||
verbose=True
|
||||
@ -104,7 +89,6 @@ class AgentExecutor:
|
||||
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
|
||||
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
|
||||
model_instance=self.configuration.model_instance,
|
||||
llm=self.configuration.model_instance.client,
|
||||
tools=self.configuration.tools,
|
||||
output_parser=StructuredChatOutputParser(),
|
||||
verbose=True
|
||||
|
||||
@ -1,13 +1,26 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Union
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult, BaseMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.callback_handler.entity.llm_message import LLMMessage
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
|
||||
ImagePromptMessageFile
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.moderation.base import ModerationOutputsResult, ModerationAction
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
|
||||
class ModerationRule(BaseModel):
|
||||
type: str
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
class LLMCallbackHandler(BaseCallbackHandler):
|
||||
@ -20,6 +33,24 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
self.start_at = None
|
||||
self.conversation_message_task = conversation_message_task
|
||||
|
||||
self.output_moderation_handler = None
|
||||
self.init_output_moderation()
|
||||
|
||||
def init_output_moderation(self):
|
||||
app_model_config = self.conversation_message_task.app_model_config
|
||||
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
|
||||
|
||||
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
|
||||
self.output_moderation_handler = OutputModerationHandler(
|
||||
tenant_id=self.conversation_message_task.tenant_id,
|
||||
app_id=self.conversation_message_task.app.id,
|
||||
rule=ModerationRule(
|
||||
type=sensitive_word_avoidance_dict.get("type"),
|
||||
config=sensitive_word_avoidance_dict.get("config")
|
||||
),
|
||||
on_message_replace_func=self.conversation_message_task.on_message_replace
|
||||
)
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
@ -42,7 +73,12 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
|
||||
real_prompts.append({
|
||||
"role": role,
|
||||
"text": message.content
|
||||
"text": message.content,
|
||||
"files": [{
|
||||
"type": file.type.value,
|
||||
"data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
|
||||
"detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
|
||||
} for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
|
||||
})
|
||||
|
||||
self.llm_message.prompt = real_prompts
|
||||
@ -59,10 +95,19 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
if not self.conversation_message_task.streaming:
|
||||
self.conversation_message_task.append_message_text(response.generations[0][0].text)
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.stop_thread()
|
||||
|
||||
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
|
||||
completion=response.generations[0][0].text,
|
||||
public_event=True if self.conversation_message_task.streaming else False
|
||||
)
|
||||
else:
|
||||
self.llm_message.completion = response.generations[0][0].text
|
||||
|
||||
if not self.conversation_message_task.streaming:
|
||||
self.conversation_message_task.append_message_text(self.llm_message.completion)
|
||||
|
||||
if response.llm_output and 'token_usage' in response.llm_output:
|
||||
if 'prompt_tokens' in response.llm_output['token_usage']:
|
||||
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
|
||||
@ -79,23 +124,161 @@ class LLMCallbackHandler(BaseCallbackHandler):
|
||||
self.conversation_message_task.save_message(self.llm_message)
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
try:
|
||||
self.conversation_message_task.append_message_text(token)
|
||||
except ConversationTaskStoppedException as ex:
|
||||
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
ex = ConversationTaskInterruptException()
|
||||
self.on_llm_error(error=ex)
|
||||
raise ex
|
||||
|
||||
self.llm_message.completion += token
|
||||
try:
|
||||
self.conversation_message_task.append_message_text(token)
|
||||
self.llm_message.completion += token
|
||||
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.append_new_token(token)
|
||||
except ConversationTaskStoppedException as ex:
|
||||
self.on_llm_error(error=ex)
|
||||
raise ex
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Do nothing."""
|
||||
if self.output_moderation_handler:
|
||||
self.output_moderation_handler.stop_thread()
|
||||
|
||||
if isinstance(error, ConversationTaskStoppedException):
|
||||
if self.conversation_message_task.streaming:
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)]
|
||||
)
|
||||
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
|
||||
if isinstance(error, ConversationTaskInterruptException):
|
||||
self.llm_message.completion = self.output_moderation_handler.get_final_output()
|
||||
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
|
||||
[PromptMessage(content=self.llm_message.completion)]
|
||||
)
|
||||
self.conversation_message_task.save_message(llm_message=self.llm_message)
|
||||
else:
|
||||
logging.debug("on_llm_error: %s", error)
|
||||
|
||||
|
||||
class OutputModerationHandler(BaseModel):
|
||||
DEFAULT_BUFFER_SIZE: int = 300
|
||||
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
|
||||
rule: ModerationRule
|
||||
on_message_replace_func: Any
|
||||
|
||||
thread: Optional[threading.Thread] = None
|
||||
thread_running: bool = True
|
||||
buffer: str = ''
|
||||
is_final_chunk: bool = False
|
||||
final_output: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def should_direct_output(self):
|
||||
return self.final_output is not None
|
||||
|
||||
def get_final_output(self):
|
||||
return self.final_output
|
||||
|
||||
def append_new_token(self, token: str):
|
||||
self.buffer += token
|
||||
|
||||
if not self.thread:
|
||||
self.thread = self.start_thread()
|
||||
|
||||
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
|
||||
self.buffer = completion
|
||||
self.is_final_chunk = True
|
||||
|
||||
result = self.moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
moderation_buffer=completion
|
||||
)
|
||||
|
||||
if not result or not result.flagged:
|
||||
return completion
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
final_output = result.preset_response
|
||||
else:
|
||||
final_output = result.text
|
||||
|
||||
if public_event:
|
||||
self.on_message_replace_func(final_output)
|
||||
|
||||
return final_output
|
||||
|
||||
def start_thread(self) -> threading.Thread:
|
||||
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
|
||||
thread = threading.Thread(target=self.worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
|
||||
})
|
||||
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
|
||||
def stop_thread(self):
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread_running = False
|
||||
|
||||
def worker(self, flask_app: Flask, buffer_size: int):
|
||||
with flask_app.app_context():
|
||||
current_length = 0
|
||||
while self.thread_running:
|
||||
moderation_buffer = self.buffer
|
||||
buffer_length = len(moderation_buffer)
|
||||
if not self.is_final_chunk:
|
||||
chunk_length = buffer_length - current_length
|
||||
if 0 <= chunk_length < buffer_size:
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
current_length = buffer_length
|
||||
|
||||
result = self.moderation(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
moderation_buffer=moderation_buffer
|
||||
)
|
||||
|
||||
if not result or not result.flagged:
|
||||
continue
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
final_output = result.preset_response
|
||||
self.final_output = final_output
|
||||
else:
|
||||
final_output = result.text + self.buffer[len(moderation_buffer):]
|
||||
|
||||
# trigger replace event
|
||||
if self.thread_running:
|
||||
self.on_message_replace_func(final_output)
|
||||
|
||||
if result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
break
|
||||
|
||||
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
|
||||
try:
|
||||
moderation_factory = ModerationFactory(
|
||||
name=self.rule.type,
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
config=self.rule.config
|
||||
)
|
||||
|
||||
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error("Moderation Output error: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
36
api/core/chain/llm_chain.py
Normal file
36
api/core/chain/llm_chain.py
Normal file
@ -0,0 +1,36 @@
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from langchain import LLMChain as LCLLMChain
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema import LLMResult, Generation
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
from core.model_providers.models.entity.message import to_prompt_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
|
||||
|
||||
class LLMChain(LCLLMChain):
|
||||
model_instance: BaseLLM
|
||||
"""The language model instance to use."""
|
||||
llm: BaseLanguageModel = FakeLLM(response="")
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||
messages = prompts[0].to_messages()
|
||||
prompt_messages = to_prompt_messages(messages)
|
||||
result = self.model_instance.run(
|
||||
messages=prompt_messages,
|
||||
stop=stop
|
||||
)
|
||||
|
||||
generations = [
|
||||
[Generation(text=result.content)]
|
||||
]
|
||||
|
||||
return LLMResult(generations=generations)
|
||||
@ -1,92 +0,0 @@
|
||||
import enum
|
||||
import logging
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.moderation import openai_moderation
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceRule(BaseModel):
|
||||
class Type(enum.Enum):
|
||||
MODERATION = "moderation"
|
||||
KEYWORDS = "keywords"
|
||||
|
||||
type: Type
|
||||
canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
|
||||
extra_params: dict = {}
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceChain(Chain):
|
||||
input_key: str = "input" #: :meta private:
|
||||
output_key: str = "output" #: :meta private:
|
||||
|
||||
model_instance: BaseLLM
|
||||
sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "sensitive_word_avoidance_chain"
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _check_sensitive_word(self, text: str) -> bool:
|
||||
for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
|
||||
if word in text:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_moderation(self, text: str) -> bool:
|
||||
moderation_model_instance = ModelFactory.get_moderation_model(
|
||||
tenant_id=self.model_instance.model_provider.provider.tenant_id,
|
||||
model_provider_name='openai',
|
||||
model_name=openai_moderation.DEFAULT_MODEL
|
||||
)
|
||||
|
||||
try:
|
||||
return moderation_model_instance.run(text=text)
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
text = inputs[self.input_key]
|
||||
|
||||
if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
|
||||
result = self._check_sensitive_word(text)
|
||||
else:
|
||||
result = self._check_moderation(text)
|
||||
|
||||
if not result:
|
||||
raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
|
||||
|
||||
return {self.output_key: text}
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceError(Exception):
|
||||
def __init__(self, message):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
@ -1,36 +1,43 @@
|
||||
import concurrent
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List, Union
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional, List, Union, Tuple
|
||||
|
||||
from flask import current_app, Flask
|
||||
from requests.exceptions import ChunkedEncodingError
|
||||
|
||||
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
|
||||
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
|
||||
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
|
||||
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
|
||||
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
|
||||
ConversationTaskInterruptException
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from core.file.file_obj import FileObj
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
|
||||
ReadOnlyConversationTokenDBBufferSharedMemory
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.orchestrator_rule_parser import OrchestratorRuleParser
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
from models.dataset import DocumentSegment, Dataset, Document
|
||||
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from models.model import App, AppModelConfig, Account, Conversation, EndUser
|
||||
from core.moderation.base import ModerationException, ModerationAction
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
|
||||
class Completion:
|
||||
@classmethod
|
||||
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
|
||||
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
|
||||
is_override: bool = False, retriever_from: str = 'dev'):
|
||||
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
|
||||
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
|
||||
auto_generate_name: bool = True):
|
||||
"""
|
||||
errors: ProviderTokenNotInitError
|
||||
"""
|
||||
query = PromptBuilder.process_template(query)
|
||||
query = PromptTemplateParser.remove_template_variables(query)
|
||||
|
||||
memory = None
|
||||
if conversation:
|
||||
@ -59,16 +66,21 @@ class Completion:
|
||||
is_override=is_override,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
streaming=streaming,
|
||||
model_instance=final_model_instance
|
||||
model_instance=final_model_instance,
|
||||
auto_generate_name=auto_generate_name
|
||||
)
|
||||
|
||||
prompt_message_files = [file.prompt_message_file for file in files]
|
||||
|
||||
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
|
||||
mode=app.mode,
|
||||
model_instance=final_model_instance,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs
|
||||
inputs=inputs,
|
||||
files=prompt_message_files
|
||||
)
|
||||
|
||||
# init orchestrator rule parser
|
||||
@ -78,26 +90,36 @@ class Completion:
|
||||
)
|
||||
|
||||
try:
|
||||
# parse sensitive_word_avoidance_chain
|
||||
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
|
||||
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
|
||||
final_model_instance, [chain_callback])
|
||||
if sensitive_word_avoidance_chain:
|
||||
try:
|
||||
query = sensitive_word_avoidance_chain.run(query)
|
||||
except SensitiveWordAvoidanceError as ex:
|
||||
cls.run_final_llm(
|
||||
model_instance=final_model_instance,
|
||||
mode=app.mode,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
agent_execute_result=None,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
fake_response=ex.message
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
|
||||
except ModerationException as e:
|
||||
cls.run_final_llm(
|
||||
model_instance=final_model_instance,
|
||||
mode=app.mode,
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
files=prompt_message_files,
|
||||
agent_execute_result=None,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
fake_response=str(e)
|
||||
)
|
||||
return
|
||||
|
||||
# fill in variable inputs from external data tools if exists
|
||||
external_data_tools = app_model_config.external_data_tools_list
|
||||
if external_data_tools:
|
||||
inputs = cls.fill_in_inputs_from_external_data_tools(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
# get agent executor
|
||||
agent_executor = orchestrator_rule_parser.to_agent_executor(
|
||||
@ -132,42 +154,152 @@ class Completion:
|
||||
app_model_config=app_model_config,
|
||||
query=query,
|
||||
inputs=inputs,
|
||||
files=prompt_message_files,
|
||||
agent_execute_result=agent_execute_result,
|
||||
conversation_message_task=conversation_message_task,
|
||||
memory=memory,
|
||||
fake_response=fake_response
|
||||
)
|
||||
except ConversationTaskStoppedException:
|
||||
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
|
||||
return
|
||||
except ChunkedEncodingError as e:
|
||||
# Interrupt by LLM (like OpenAI), handle it.
|
||||
logging.warning(f'ChunkedEncodingError: {e}')
|
||||
conversation_message_task.end()
|
||||
return
|
||||
|
||||
|
||||
@classmethod
|
||||
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
|
||||
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
|
||||
return inputs, query
|
||||
|
||||
type = app_model_config.sensitive_word_avoidance_dict['type']
|
||||
|
||||
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
|
||||
moderation_result = moderation.moderation_for_inputs(inputs, query)
|
||||
|
||||
if not moderation_result.flagged:
|
||||
return inputs, query
|
||||
|
||||
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
raise ModerationException(moderation_result.preset_response)
|
||||
elif moderation_result.action == ModerationAction.OVERRIDED:
|
||||
inputs = moderation_result.inputs
|
||||
query = moderation_result.query
|
||||
|
||||
return inputs, query
|
||||
|
||||
@classmethod
|
||||
def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
|
||||
inputs: dict, query: str) -> dict:
|
||||
"""
|
||||
Fill in variable inputs from external data tools if exists.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param app_id: app id
|
||||
:param external_data_tools: external data tools configs
|
||||
:param inputs: the inputs
|
||||
:param query: the query
|
||||
:return: the filled inputs
|
||||
"""
|
||||
# Group tools by type and config
|
||||
grouped_tools = {}
|
||||
for tool in external_data_tools:
|
||||
if not tool.get("enabled"):
|
||||
continue
|
||||
|
||||
tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
|
||||
grouped_tools.setdefault(tool_key, []).append(tool)
|
||||
|
||||
results = {}
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = {}
|
||||
for tool in external_data_tools:
|
||||
if not tool.get("enabled"):
|
||||
continue
|
||||
|
||||
future = executor.submit(
|
||||
cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
|
||||
inputs, query
|
||||
)
|
||||
|
||||
futures[future] = tool
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
tool_variable, result = future.result()
|
||||
results[tool_variable] = result
|
||||
|
||||
inputs.update(results)
|
||||
return inputs
|
||||
|
||||
@classmethod
|
||||
def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
|
||||
inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
with flask_app.app_context():
|
||||
tool_variable = external_data_tool.get("variable")
|
||||
tool_type = external_data_tool.get("type")
|
||||
tool_config = external_data_tool.get("config")
|
||||
|
||||
external_data_tool_factory = ExternalDataToolFactory(
|
||||
name=tool_type,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
variable=tool_variable,
|
||||
config=tool_config
|
||||
)
|
||||
|
||||
# query external data tool
|
||||
result = external_data_tool_factory.query(
|
||||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
return tool_variable, result
|
||||
|
||||
@classmethod
|
||||
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
|
||||
if app.mode != 'completion':
|
||||
return query
|
||||
|
||||
|
||||
return inputs.get(app_model_config.dataset_query_variable, "")
|
||||
|
||||
@classmethod
|
||||
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
|
||||
inputs: dict,
|
||||
files: List[PromptMessageFile],
|
||||
agent_execute_result: Optional[AgentExecuteResult],
|
||||
conversation_message_task: ConversationMessageTask,
|
||||
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
|
||||
fake_response: Optional[str]):
|
||||
prompt_transform = PromptTransform()
|
||||
|
||||
# get llm prompt
|
||||
prompt_messages, stop_words = model_instance.get_prompt(
|
||||
mode=mode,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
context=agent_execute_result.output if agent_execute_result else None,
|
||||
memory=memory
|
||||
)
|
||||
if app_model_config.prompt_type == 'simple':
|
||||
prompt_messages, stop_words = prompt_transform.get_prompt(
|
||||
app_mode=mode,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=agent_execute_result.output if agent_execute_result else None,
|
||||
memory=memory,
|
||||
model_instance=model_instance
|
||||
)
|
||||
else:
|
||||
prompt_messages = prompt_transform.get_advanced_prompt(
|
||||
app_mode=mode,
|
||||
app_model_config=app_model_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=agent_execute_result.output if agent_execute_result else None,
|
||||
memory=memory,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
model_config = app_model_config.model_dict
|
||||
completion_params = model_config.get("completion_params", {})
|
||||
stop_words = completion_params.get("stop", [])
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
model_instance=model_instance,
|
||||
@ -176,7 +308,7 @@ class Completion:
|
||||
|
||||
response = model_instance.run(
|
||||
messages=prompt_messages,
|
||||
stop=stop_words,
|
||||
stop=stop_words if stop_words else None,
|
||||
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
|
||||
fake_response=fake_response
|
||||
)
|
||||
@ -217,7 +349,7 @@ class Completion:
|
||||
|
||||
@classmethod
|
||||
def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
|
||||
query: str, inputs: dict) -> int:
|
||||
query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
|
||||
model_limited_tokens = model_instance.model_rules.max_tokens.max
|
||||
max_tokens = model_instance.get_model_kwargs().max_tokens
|
||||
|
||||
@ -227,15 +359,31 @@ class Completion:
|
||||
if max_tokens is None:
|
||||
max_tokens = 0
|
||||
|
||||
prompt_transform = PromptTransform()
|
||||
|
||||
# get prompt without memory and context
|
||||
prompt_messages, _ = model_instance.get_prompt(
|
||||
mode=mode,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
context=None,
|
||||
memory=None
|
||||
)
|
||||
if app_model_config.prompt_type == 'simple':
|
||||
prompt_messages, _ = prompt_transform.get_prompt(
|
||||
app_mode=mode,
|
||||
pre_prompt=app_model_config.pre_prompt,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=None,
|
||||
memory=None,
|
||||
model_instance=model_instance
|
||||
)
|
||||
else:
|
||||
prompt_messages = prompt_transform.get_advanced_prompt(
|
||||
app_mode=mode,
|
||||
app_model_config=app_model_config,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
files=files,
|
||||
context=None,
|
||||
memory=None,
|
||||
model_instance=model_instance
|
||||
)
|
||||
|
||||
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
|
||||
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
|
||||
@ -266,52 +414,3 @@ class Completion:
|
||||
model_kwargs = model_instance.get_model_kwargs()
|
||||
model_kwargs.max_tokens = max_tokens
|
||||
model_instance.set_model_kwargs(model_kwargs)
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
|
||||
app_model_config: AppModelConfig, user: Account, streaming: bool):
|
||||
|
||||
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
|
||||
tenant_id=app.tenant_id,
|
||||
model_config=app_model_config.model_dict,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
# get llm prompt
|
||||
old_prompt_messages, _ = final_model_instance.get_prompt(
|
||||
mode='completion',
|
||||
pre_prompt=pre_prompt,
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
context=None,
|
||||
memory=None
|
||||
)
|
||||
|
||||
original_completion = message.answer.strip()
|
||||
|
||||
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
|
||||
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
|
||||
|
||||
prompt_messages = [PromptMessage(content=prompt)]
|
||||
|
||||
conversation_message_task = ConversationMessageTask(
|
||||
task_id=task_id,
|
||||
app=app,
|
||||
app_model_config=app_model_config,
|
||||
user=user,
|
||||
inputs=message.inputs,
|
||||
query=message.query,
|
||||
is_override=True if message.override_model_configs else False,
|
||||
streaming=streaming,
|
||||
model_instance=final_model_instance
|
||||
)
|
||||
|
||||
cls.recale_llm_max_tokens(
|
||||
model_instance=final_model_instance,
|
||||
prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
final_model_instance.run(
|
||||
messages=prompt_messages,
|
||||
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
|
||||
)
|
||||
|
||||
@ -6,23 +6,25 @@ from core.callback_handler.entity.agent_loop import AgentLoop
|
||||
from core.callback_handler.entity.dataset_query import DatasetQueryObj
|
||||
from core.callback_handler.entity.llm_message import LLMMessage
|
||||
from core.callback_handler.entity.chain_result import ChainResult
|
||||
from core.file.file_obj import FileObj
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
|
||||
from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import DatasetQuery
|
||||
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
|
||||
MessageChain, DatasetRetrieverResource
|
||||
MessageChain, DatasetRetrieverResource, MessageFile
|
||||
|
||||
|
||||
class ConversationMessageTask:
|
||||
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
|
||||
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
|
||||
conversation: Optional[Conversation] = None, is_override: bool = False):
|
||||
inputs: dict, query: str, files: List[FileObj], streaming: bool,
|
||||
model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
|
||||
auto_generate_name: bool = True):
|
||||
self.start_at = time.perf_counter()
|
||||
|
||||
self.task_id = task_id
|
||||
@ -35,6 +37,7 @@ class ConversationMessageTask:
|
||||
self.user = user
|
||||
self.inputs = inputs
|
||||
self.query = query
|
||||
self.files = files
|
||||
self.streaming = streaming
|
||||
|
||||
self.conversation = conversation
|
||||
@ -45,6 +48,7 @@ class ConversationMessageTask:
|
||||
self.message = None
|
||||
|
||||
self.retriever_resource = None
|
||||
self.auto_generate_name = auto_generate_name
|
||||
|
||||
self.model_dict = self.app_model_config.model_dict
|
||||
self.provider_name = self.model_dict.get('provider')
|
||||
@ -74,10 +78,10 @@ class ConversationMessageTask:
|
||||
if self.mode == 'chat':
|
||||
introduction = self.app_model_config.opening_statement
|
||||
if introduction:
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=introduction)
|
||||
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
|
||||
prompt_template = PromptTemplateParser(template=introduction)
|
||||
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
|
||||
try:
|
||||
introduction = prompt_template.format(**prompt_inputs)
|
||||
introduction = prompt_template.format(prompt_inputs)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@ -100,7 +104,7 @@ class ConversationMessageTask:
|
||||
model_id=self.model_name,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=self.mode,
|
||||
name='',
|
||||
name='New conversation',
|
||||
inputs=self.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction=system_instruction,
|
||||
@ -142,6 +146,19 @@ class ConversationMessageTask:
|
||||
db.session.add(self.message)
|
||||
db.session.commit()
|
||||
|
||||
for file in self.files:
|
||||
message_file = MessageFile(
|
||||
message_id=self.message.id,
|
||||
type=file.type.value,
|
||||
transfer_method=file.transfer_method.value,
|
||||
url=file.url,
|
||||
upload_file_id=file.upload_file_id,
|
||||
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
|
||||
created_by=self.user.id
|
||||
)
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
|
||||
def append_message_text(self, text: str):
|
||||
if text is not None:
|
||||
self._pub_handler.pub_text(text)
|
||||
@ -150,12 +167,12 @@ class ConversationMessageTask:
|
||||
message_tokens = llm_message.prompt_tokens
|
||||
answer_tokens = llm_message.completion_tokens
|
||||
|
||||
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
|
||||
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
|
||||
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
|
||||
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
|
||||
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
|
||||
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
|
||||
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
|
||||
total_price = message_total_price + answer_total_price
|
||||
|
||||
@ -163,7 +180,7 @@ class ConversationMessageTask:
|
||||
self.message.message_tokens = message_tokens
|
||||
self.message.message_unit_price = message_unit_price
|
||||
self.message.message_price_unit = message_price_unit
|
||||
self.message.answer = PromptBuilder.process_template(
|
||||
self.message.answer = PromptTemplateParser.remove_template_variables(
|
||||
llm_message.completion.strip()) if llm_message.completion else ''
|
||||
self.message.answer_tokens = answer_tokens
|
||||
self.message.answer_unit_price = answer_unit_price
|
||||
@ -176,7 +193,8 @@ class ConversationMessageTask:
|
||||
message_was_created.send(
|
||||
self.message,
|
||||
conversation=self.conversation,
|
||||
is_first_message=self.is_new_conversation
|
||||
is_first_message=self.is_new_conversation,
|
||||
auto_generate_name=self.auto_generate_name
|
||||
)
|
||||
|
||||
if not by_stopped:
|
||||
@ -226,15 +244,15 @@ class ConversationMessageTask:
|
||||
|
||||
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
|
||||
agent_loop: AgentLoop):
|
||||
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
|
||||
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
|
||||
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
|
||||
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
|
||||
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
|
||||
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
|
||||
|
||||
loop_message_tokens = agent_loop.prompt_tokens
|
||||
loop_answer_tokens = agent_loop.completion_tokens
|
||||
|
||||
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
|
||||
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
|
||||
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
|
||||
loop_total_price = loop_message_total_price + loop_answer_total_price
|
||||
|
||||
@ -290,6 +308,10 @@ class ConversationMessageTask:
|
||||
db.session.commit()
|
||||
self.retriever_resource = resource
|
||||
|
||||
def on_message_replace(self, text: str):
|
||||
if text is not None:
|
||||
self._pub_handler.pub_message_replace(text)
|
||||
|
||||
def message_end(self):
|
||||
self._pub_handler.pub_message_end(self.retriever_resource)
|
||||
|
||||
@ -342,6 +364,24 @@ class PubHandler:
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_message_replace(self, text: str):
|
||||
content = {
|
||||
'event': 'message_replace',
|
||||
'data': {
|
||||
'task_id': self._task_id,
|
||||
'message_id': str(self._message.id),
|
||||
'text': text,
|
||||
'mode': self._conversation.mode,
|
||||
'conversation_id': str(self._conversation.id)
|
||||
}
|
||||
}
|
||||
|
||||
redis_client.publish(self._channel, json.dumps(content))
|
||||
|
||||
if self._is_stopped():
|
||||
self.pub_end()
|
||||
raise ConversationTaskStoppedException()
|
||||
|
||||
def pub_chain(self, message_chain: MessageChain):
|
||||
if self._chain_pub:
|
||||
content = {
|
||||
@ -443,3 +483,7 @@ class PubHandler:
|
||||
|
||||
class ConversationTaskStoppedException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ConversationTaskInterruptException(Exception):
|
||||
pass
|
||||
|
||||
0
api/core/extension/__init__.py
Normal file
0
api/core/extension/__init__.py
Normal file
62
api/core/extension/api_based_extension_requestor.py
Normal file
62
api/core/extension/api_based_extension_requestor.py
Normal file
@ -0,0 +1,62 @@
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from models.api_based_extension import APIBasedExtensionPoint
|
||||
|
||||
|
||||
class APIBasedExtensionRequestor:
|
||||
timeout: (int, int) = (5, 60)
|
||||
"""timeout for request connect and read"""
|
||||
|
||||
def __init__(self, api_endpoint: str, api_key: str) -> None:
|
||||
self.api_endpoint = api_endpoint
|
||||
self.api_key = api_key
|
||||
|
||||
def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
|
||||
"""
|
||||
Request the api.
|
||||
|
||||
:param point: the api point
|
||||
:param params: the request params
|
||||
:return: the response json
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer {}".format(self.api_key)
|
||||
}
|
||||
|
||||
url = self.api_endpoint
|
||||
|
||||
try:
|
||||
# proxy support for security
|
||||
proxies = None
|
||||
if os.environ.get("API_BASED_EXTENSION_HTTP_PROXY") and os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"):
|
||||
proxies = {
|
||||
'http': os.environ.get("API_BASED_EXTENSION_HTTP_PROXY"),
|
||||
'https': os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"),
|
||||
}
|
||||
|
||||
response = requests.request(
|
||||
method='POST',
|
||||
url=url,
|
||||
json={
|
||||
'point': point.value,
|
||||
'params': params
|
||||
},
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
proxies=proxies
|
||||
)
|
||||
except requests.exceptions.Timeout:
|
||||
raise ValueError("request timeout")
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise ValueError("request connection error")
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ValueError("request error, status_code: {}, content: {}".format(
|
||||
response.status_code,
|
||||
response.text[:100]
|
||||
))
|
||||
|
||||
return response.json()
|
||||
111
api/core/extension/extensible.py
Normal file
111
api/core/extension/extensible.py
Normal file
@ -0,0 +1,111 @@
|
||||
import enum
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ExtensionModule(enum.Enum):
|
||||
MODERATION = 'moderation'
|
||||
EXTERNAL_DATA_TOOL = 'external_data_tool'
|
||||
|
||||
|
||||
class ModuleExtension(BaseModel):
|
||||
extension_class: Any
|
||||
name: str
|
||||
label: Optional[dict] = None
|
||||
form_schema: Optional[list] = None
|
||||
builtin: bool = True
|
||||
position: Optional[int] = None
|
||||
|
||||
|
||||
class Extensible:
|
||||
module: ExtensionModule
|
||||
|
||||
name: str
|
||||
tenant_id: str
|
||||
config: Optional[dict] = None
|
||||
|
||||
def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def scan_extensions(cls):
|
||||
extensions = {}
|
||||
|
||||
# get the path of the current class
|
||||
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
|
||||
current_dir_path = os.path.dirname(current_path)
|
||||
|
||||
# traverse subdirectories
|
||||
for subdir_name in os.listdir(current_dir_path):
|
||||
if subdir_name.startswith('__'):
|
||||
continue
|
||||
|
||||
subdir_path = os.path.join(current_dir_path, subdir_name)
|
||||
extension_name = subdir_name
|
||||
if os.path.isdir(subdir_path):
|
||||
file_names = os.listdir(subdir_path)
|
||||
|
||||
# is builtin extension, builtin extension
|
||||
# in the front-end page and business logic, there are special treatments.
|
||||
builtin = False
|
||||
position = None
|
||||
if '__builtin__' in file_names:
|
||||
builtin = True
|
||||
|
||||
builtin_file_path = os.path.join(subdir_path, '__builtin__')
|
||||
if os.path.exists(builtin_file_path):
|
||||
with open(builtin_file_path, 'r') as f:
|
||||
position = int(f.read().strip())
|
||||
|
||||
if (extension_name + '.py') not in file_names:
|
||||
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
|
||||
continue
|
||||
|
||||
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
|
||||
py_path = os.path.join(subdir_path, extension_name + '.py')
|
||||
spec = importlib.util.spec_from_file_location(extension_name, py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
extension_class = None
|
||||
for name, obj in vars(mod).items():
|
||||
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
|
||||
extension_class = obj
|
||||
break
|
||||
|
||||
if not extension_class:
|
||||
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
|
||||
continue
|
||||
|
||||
json_data = {}
|
||||
if not builtin:
|
||||
if 'schema.json' not in file_names:
|
||||
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
|
||||
continue
|
||||
|
||||
json_path = os.path.join(subdir_path, 'schema.json')
|
||||
json_data = {}
|
||||
if os.path.exists(json_path):
|
||||
with open(json_path, 'r') as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
extensions[extension_name] = ModuleExtension(
|
||||
extension_class=extension_class,
|
||||
name=extension_name,
|
||||
label=json_data.get('label'),
|
||||
form_schema=json_data.get('form_schema'),
|
||||
builtin=builtin,
|
||||
position=position
|
||||
)
|
||||
|
||||
sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
|
||||
sorted_extensions = OrderedDict(sorted_items)
|
||||
|
||||
return sorted_extensions
|
||||
47
api/core/extension/extension.py
Normal file
47
api/core/extension/extension.py
Normal file
@ -0,0 +1,47 @@
|
||||
from core.extension.extensible import ModuleExtension, ExtensionModule
|
||||
from core.external_data_tool.base import ExternalDataTool
|
||||
from core.moderation.base import Moderation
|
||||
|
||||
|
||||
class Extension:
|
||||
__module_extensions: dict[str, dict[str, ModuleExtension]] = {}
|
||||
|
||||
module_classes = {
|
||||
ExtensionModule.MODERATION: Moderation,
|
||||
ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
|
||||
}
|
||||
|
||||
def init(self):
|
||||
for module, module_class in self.module_classes.items():
|
||||
self.__module_extensions[module.value] = module_class.scan_extensions()
|
||||
|
||||
def module_extensions(self, module: str) -> list[ModuleExtension]:
|
||||
module_extensions = self.__module_extensions.get(module)
|
||||
|
||||
if not module_extensions:
|
||||
raise ValueError(f"Extension Module {module} not found")
|
||||
|
||||
return list(module_extensions.values())
|
||||
|
||||
def module_extension(self, module: ExtensionModule, extension_name: str) -> ModuleExtension:
|
||||
module_extensions = self.__module_extensions.get(module.value)
|
||||
|
||||
if not module_extensions:
|
||||
raise ValueError(f"Extension Module {module} not found")
|
||||
|
||||
module_extension = module_extensions.get(extension_name)
|
||||
|
||||
if not module_extension:
|
||||
raise ValueError(f"Extension {extension_name} not found")
|
||||
|
||||
return module_extension
|
||||
|
||||
def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
|
||||
module_extension = self.module_extension(module, extension_name)
|
||||
return module_extension.extension_class
|
||||
|
||||
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
|
||||
module_extension = self.module_extension(module, extension_name)
|
||||
form_schema = module_extension.form_schema
|
||||
|
||||
# TODO validate form_schema
|
||||
0
api/core/external_data_tool/__init__.py
Normal file
0
api/core/external_data_tool/__init__.py
Normal file
1
api/core/external_data_tool/api/__builtin__
Normal file
1
api/core/external_data_tool/api/__builtin__
Normal file
@ -0,0 +1 @@
|
||||
1
|
||||
0
api/core/external_data_tool/api/__init__.py
Normal file
0
api/core/external_data_tool/api/__init__.py
Normal file
92
api/core/external_data_tool/api/api.py
Normal file
92
api/core/external_data_tool/api/api.py
Normal file
@ -0,0 +1,92 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||
from core.external_data_tool.base import ExternalDataTool
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
|
||||
|
||||
class ApiExternalDataTool(ExternalDataTool):
|
||||
"""
|
||||
The api external data tool.
|
||||
"""
|
||||
|
||||
name: str = "api"
|
||||
"""the unique name of external data tool"""
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict) -> None:
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
:param tenant_id: the id of workspace
|
||||
:param config: the form config data
|
||||
:return:
|
||||
"""
|
||||
# own validation logic
|
||||
api_based_extension_id = config.get("api_based_extension_id")
|
||||
if not api_based_extension_id:
|
||||
raise ValueError("api_based_extension_id is required")
|
||||
|
||||
# get api_based_extension
|
||||
api_based_extension = db.session.query(APIBasedExtension).filter(
|
||||
APIBasedExtension.tenant_id == tenant_id,
|
||||
APIBasedExtension.id == api_based_extension_id
|
||||
).first()
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError("api_based_extension_id is invalid")
|
||||
|
||||
def query(self, inputs: dict, query: Optional[str] = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
:param inputs: user inputs
|
||||
:param query: the query of chat app
|
||||
:return: the tool query result
|
||||
"""
|
||||
# get params from config
|
||||
api_based_extension_id = self.config.get("api_based_extension_id")
|
||||
|
||||
# get api_based_extension
|
||||
api_based_extension = db.session.query(APIBasedExtension).filter(
|
||||
APIBasedExtension.tenant_id == self.tenant_id,
|
||||
APIBasedExtension.id == api_based_extension_id
|
||||
).first()
|
||||
|
||||
if not api_based_extension:
|
||||
raise ValueError("[External data tool] API query failed, variable: {}, "
|
||||
"error: api_based_extension_id is invalid"
|
||||
.format(self.config.get('variable')))
|
||||
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id,
|
||||
token=api_based_extension.api_key
|
||||
)
|
||||
|
||||
try:
|
||||
# request api
|
||||
requestor = APIBasedExtensionRequestor(
|
||||
api_endpoint=api_based_extension.api_endpoint,
|
||||
api_key=api_key
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
|
||||
self.config.get('variable'),
|
||||
e
|
||||
))
|
||||
|
||||
response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
|
||||
'app_id': self.app_id,
|
||||
'tool_variable': self.variable,
|
||||
'inputs': inputs,
|
||||
'query': query
|
||||
})
|
||||
|
||||
if 'result' not in response_json:
|
||||
raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
|
||||
.format(self.config.get('variable')))
|
||||
|
||||
return response_json['result']
|
||||
45
api/core/external_data_tool/base.py
Normal file
45
api/core/external_data_tool/base.py
Normal file
@ -0,0 +1,45 @@
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Optional
|
||||
|
||||
from core.extension.extensible import Extensible, ExtensionModule
|
||||
|
||||
|
||||
class ExternalDataTool(Extensible, ABC):
|
||||
"""
|
||||
The base class of external data tool.
|
||||
"""
|
||||
|
||||
module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL
|
||||
|
||||
app_id: str
|
||||
"""the id of app"""
|
||||
variable: str
|
||||
"""the tool variable name of app tool"""
|
||||
|
||||
def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None:
|
||||
super().__init__(tenant_id, config)
|
||||
self.app_id = app_id
|
||||
self.variable = variable
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict) -> None:
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
:param tenant_id: the id of workspace
|
||||
:param config: the form config data
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def query(self, inputs: dict, query: Optional[str] = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
:param inputs: user inputs
|
||||
:param query: the query of chat app
|
||||
:return: the tool query result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
40
api/core/external_data_tool/factory.py
Normal file
40
api/core/external_data_tool/factory.py
Normal file
@ -0,0 +1,40 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.extension.extensible import ExtensionModule
|
||||
from extensions.ext_code_based_extension import code_based_extension
|
||||
|
||||
|
||||
class ExternalDataToolFactory:
|
||||
|
||||
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
|
||||
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||
self.__extension_instance = extension_class(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
variable=variable,
|
||||
config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
:param name: the name of external data tool
|
||||
:param tenant_id: the id of workspace
|
||||
:param config: the form config data
|
||||
:return:
|
||||
"""
|
||||
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
|
||||
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||
extension_class.validate_config(tenant_id, config)
|
||||
|
||||
def query(self, inputs: dict, query: Optional[str] = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
:param inputs: user inputs
|
||||
:param query: the query of chat app
|
||||
:return: the tool query result
|
||||
"""
|
||||
return self.__extension_instance.query(inputs, query)
|
||||
0
api/core/file/__init__.py
Normal file
0
api/core/file/__init__.py
Normal file
79
api/core/file/file_obj.py
Normal file
79
api/core/file/file_obj.py
Normal file
@ -0,0 +1,79 @@
|
||||
import enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.file.upload_file_parser import UploadFileParser
|
||||
from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
|
||||
from extensions.ext_database import db
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
class FileType(enum.Enum):
|
||||
IMAGE = 'image'
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in FileType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class FileTransferMethod(enum.Enum):
|
||||
REMOTE_URL = 'remote_url'
|
||||
LOCAL_FILE = 'local_file'
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in FileTransferMethod:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class FileObj(BaseModel):
|
||||
id: Optional[str]
|
||||
tenant_id: str
|
||||
type: FileType
|
||||
transfer_method: FileTransferMethod
|
||||
url: Optional[str]
|
||||
upload_file_id: Optional[str]
|
||||
file_config: dict
|
||||
|
||||
@property
|
||||
def data(self) -> Optional[str]:
|
||||
return self._get_data()
|
||||
|
||||
@property
|
||||
def preview_url(self) -> Optional[str]:
|
||||
return self._get_data(force_url=True)
|
||||
|
||||
@property
|
||||
def prompt_message_file(self) -> PromptMessageFile:
|
||||
if self.type == FileType.IMAGE:
|
||||
image_config = self.file_config.get('image')
|
||||
|
||||
return ImagePromptMessageFile(
|
||||
data=self.data,
|
||||
detail=ImagePromptMessageFile.DETAIL.HIGH
|
||||
if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
|
||||
)
|
||||
|
||||
def _get_data(self, force_url: bool = False) -> Optional[str]:
|
||||
if self.type == FileType.IMAGE:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
upload_file = (db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == self.upload_file_id,
|
||||
UploadFile.tenant_id == self.tenant_id
|
||||
).first())
|
||||
|
||||
return UploadFileParser.get_image_data(
|
||||
upload_file=upload_file,
|
||||
force_url=force_url
|
||||
)
|
||||
|
||||
return None
|
||||
180
api/core/file/message_file_parser.py
Normal file
180
api/core/file/message_file_parser.py
Normal file
@ -0,0 +1,180 @@
|
||||
from typing import List, Union, Optional, Dict
|
||||
|
||||
import requests
|
||||
|
||||
from core.file.file_obj import FileObj, FileType, FileTransferMethod
|
||||
from core.file.upload_file_parser import SUPPORT_EXTENSIONS
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import MessageFile, EndUser, AppModelConfig, UploadFile
|
||||
|
||||
|
||||
class MessageFileParser:
|
||||
|
||||
def __init__(self, tenant_id: str, app_id: str) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
|
||||
def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
|
||||
user: Union[Account, EndUser]) -> List[FileObj]:
|
||||
"""
|
||||
validate and transform files arg
|
||||
|
||||
:param files:
|
||||
:param app_model_config:
|
||||
:param user:
|
||||
:return:
|
||||
"""
|
||||
file_upload_config = app_model_config.file_upload_dict
|
||||
|
||||
for file in files:
|
||||
if not isinstance(file, dict):
|
||||
raise ValueError('Invalid file format, must be dict')
|
||||
if not file.get('type'):
|
||||
raise ValueError('Missing file type')
|
||||
FileType.value_of(file.get('type'))
|
||||
if not file.get('transfer_method'):
|
||||
raise ValueError('Missing file transfer method')
|
||||
FileTransferMethod.value_of(file.get('transfer_method'))
|
||||
if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
|
||||
if not file.get('url'):
|
||||
raise ValueError('Missing file url')
|
||||
if not file.get('url').startswith('http'):
|
||||
raise ValueError('Invalid file url')
|
||||
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
|
||||
raise ValueError('Missing file upload_file_id')
|
||||
|
||||
# transform files to file objs
|
||||
type_file_objs = self._to_file_objs(files, file_upload_config)
|
||||
|
||||
# validate files
|
||||
new_files = []
|
||||
for file_type, file_objs in type_file_objs.items():
|
||||
if file_type == FileType.IMAGE:
|
||||
# parse and validate files
|
||||
image_config = file_upload_config.get('image')
|
||||
|
||||
# check if image file feature is enabled
|
||||
if not image_config['enabled']:
|
||||
continue
|
||||
|
||||
# Validate number of files
|
||||
if len(files) > image_config['number_limits']:
|
||||
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
|
||||
|
||||
for file_obj in file_objs:
|
||||
# Validate transfer method
|
||||
if file_obj.transfer_method.value not in image_config['transfer_methods']:
|
||||
raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')
|
||||
|
||||
# Validate file type
|
||||
if file_obj.type != FileType.IMAGE:
|
||||
raise ValueError(f'Invalid file type: {file_obj.type}')
|
||||
|
||||
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
# check remote url valid and is image
|
||||
result, error = self._check_image_remote_url(file_obj.url)
|
||||
if result is False:
|
||||
raise ValueError(error)
|
||||
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
# get upload file from upload_file_id
|
||||
upload_file = (db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == file_obj.upload_file_id,
|
||||
UploadFile.tenant_id == self.tenant_id,
|
||||
UploadFile.created_by == user.id,
|
||||
UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
UploadFile.extension.in_(SUPPORT_EXTENSIONS)
|
||||
).first())
|
||||
|
||||
# check upload file is belong to tenant and user
|
||||
if not upload_file:
|
||||
raise ValueError('Invalid upload file')
|
||||
|
||||
new_files.append(file_obj)
|
||||
|
||||
# return all file objs
|
||||
return new_files
|
||||
|
||||
def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
|
||||
"""
|
||||
transform message files
|
||||
|
||||
:param files:
|
||||
:param app_model_config:
|
||||
:return:
|
||||
"""
|
||||
# transform files to file objs
|
||||
type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict)
|
||||
|
||||
# return all file objs
|
||||
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
|
||||
|
||||
def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
|
||||
file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
|
||||
"""
|
||||
transform files to file objs
|
||||
|
||||
:param files:
|
||||
:param file_upload_config:
|
||||
:return:
|
||||
"""
|
||||
type_file_objs: Dict[FileType, List[FileObj]] = {
|
||||
# Currently only support image
|
||||
FileType.IMAGE: []
|
||||
}
|
||||
|
||||
if not files:
|
||||
return type_file_objs
|
||||
|
||||
# group by file type and convert file args or message files to FileObj
|
||||
for file in files:
|
||||
file_obj = self._to_file_obj(file, file_upload_config)
|
||||
if file_obj.type not in type_file_objs:
|
||||
continue
|
||||
|
||||
type_file_objs[file_obj.type].append(file_obj)
|
||||
|
||||
return type_file_objs
|
||||
|
||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
|
||||
"""
|
||||
transform file to file obj
|
||||
|
||||
:param file:
|
||||
:return:
|
||||
"""
|
||||
if isinstance(file, dict):
|
||||
transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
|
||||
return FileObj(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.value_of(file.get('type')),
|
||||
transfer_method=transfer_method,
|
||||
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
file_config=file_upload_config
|
||||
)
|
||||
else:
|
||||
return FileObj(
|
||||
id=file.id,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.value_of(file.type),
|
||||
transfer_method=FileTransferMethod.value_of(file.transfer_method),
|
||||
url=file.url,
|
||||
upload_file_id=file.upload_file_id or None,
|
||||
file_config=file_upload_config
|
||||
)
|
||||
|
||||
def _check_image_remote_url(self, url):
|
||||
try:
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
|
||||
}
|
||||
|
||||
response = requests.head(url, headers=headers, allow_redirects=True)
|
||||
if response.status_code == 200:
|
||||
return True, ""
|
||||
else:
|
||||
return False, "URL does not exist."
|
||||
except requests.RequestException as e:
|
||||
return False, f"Error checking URL: {e}"
|
||||
79
api/core/file/upload_file_parser.py
Normal file
79
api/core/file/upload_file_parser.py
Normal file
@ -0,0 +1,79 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
@classmethod
|
||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
if upload_file.extension not in SUPPORT_EXTENSIONS:
|
||||
return None
|
||||
|
||||
if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file)
|
||||
else:
|
||||
# get image file base64
|
||||
try:
|
||||
data = storage.load(upload_file.key)
|
||||
except FileNotFoundError:
|
||||
logging.error(f'File not found: {upload_file.key}')
|
||||
return None
|
||||
|
||||
encoded_string = base64.b64encode(data).decode('utf-8')
|
||||
return f'data:{upload_file.mime_type};base64,{encoded_string}'
|
||||
|
||||
@classmethod
|
||||
def get_signed_temp_image_url(cls, upload_file) -> str:
|
||||
"""
|
||||
get signed url from upload file
|
||||
|
||||
:param upload_file: UploadFile object
|
||||
:return:
|
||||
"""
|
||||
base_url = current_app.config.get('FILES_URL')
|
||||
image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}"
|
||||
secret_key = current_app.config['SECRET_KEY'].encode()
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
@classmethod
|
||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
|
||||
:param upload_file_id: file id
|
||||
:param timestamp: timestamp
|
||||
:param nonce: nonce
|
||||
:param sign: signature
|
||||
:return:
|
||||
"""
|
||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = current_app.config['SECRET_KEY'].encode()
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= 300 # expired after 5 minutes
|
||||
@ -10,14 +10,13 @@ from core.model_providers.models.entity.model_params import ModelKwargs
|
||||
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||
|
||||
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
|
||||
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
|
||||
GENERATOR_QA_PROMPT
|
||||
from core.prompt.prompt_template import PromptTemplateParser
|
||||
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
|
||||
|
||||
|
||||
class LLMGenerator:
|
||||
@classmethod
|
||||
def generate_conversation_name(cls, tenant_id: str, query, answer):
|
||||
def generate_conversation_name(cls, tenant_id: str, query):
|
||||
prompt = CONVERSATION_TITLE_PROMPT
|
||||
|
||||
if len(query) > 2000:
|
||||
@ -41,81 +40,26 @@ class LLMGenerator:
|
||||
|
||||
result_dict = json.loads(answer)
|
||||
answer = result_dict['Your Output']
|
||||
name = answer.strip()
|
||||
|
||||
return answer.strip()
|
||||
if len(name) > 75:
|
||||
name = name[:75] + '...'
|
||||
|
||||
@classmethod
|
||||
def generate_conversation_summary(cls, tenant_id: str, messages):
|
||||
max_tokens = 200
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
model_kwargs=ModelKwargs(
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
)
|
||||
|
||||
prompt = CONVERSATION_SUMMARY_PROMPT
|
||||
prompt_with_empty_context = prompt.format(context='')
|
||||
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
|
||||
max_context_token_length = model_instance.model_rules.max_tokens.max
|
||||
max_context_token_length = max_context_token_length if max_context_token_length else 1500
|
||||
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
|
||||
|
||||
context = ''
|
||||
for message in messages:
|
||||
if not message.answer:
|
||||
continue
|
||||
|
||||
if len(message.query) > 2000:
|
||||
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
|
||||
else:
|
||||
query = message.query
|
||||
|
||||
if len(message.answer) > 2000:
|
||||
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
|
||||
else:
|
||||
answer = message.answer
|
||||
|
||||
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
|
||||
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
|
||||
context += message_qa_text
|
||||
|
||||
if not context:
|
||||
return '[message too long, no summary]'
|
||||
|
||||
prompt = prompt.format(context=context)
|
||||
prompts = [PromptMessage(content=prompt)]
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_introduction(cls, tenant_id: str, pre_prompt: str):
|
||||
prompt = INTRODUCTION_GENERATE_PROMPT
|
||||
prompt = prompt.format(prompt=pre_prompt)
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
prompts = [PromptMessage(content=prompt)]
|
||||
response = model_instance.run(prompts)
|
||||
answer = response.content
|
||||
return answer.strip()
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
|
||||
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
prompt = JinjaPromptTemplate(
|
||||
template="{{histories}}\n{{format_instructions}}\nquestions:\n",
|
||||
input_variables=["histories"],
|
||||
partial_variables={"format_instructions": format_instructions}
|
||||
prompt_template = PromptTemplateParser(
|
||||
template="{{histories}}\n{{format_instructions}}\nquestions:\n"
|
||||
)
|
||||
|
||||
_input = prompt.format_prompt(histories=histories)
|
||||
prompt = prompt_template.format({
|
||||
"histories": histories,
|
||||
"format_instructions": format_instructions
|
||||
})
|
||||
|
||||
try:
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
@ -128,10 +72,10 @@ class LLMGenerator:
|
||||
except ProviderTokenNotInitError:
|
||||
return []
|
||||
|
||||
prompts = [PromptMessage(content=_input.to_string())]
|
||||
prompt_messages = [PromptMessage(content=prompt)]
|
||||
|
||||
try:
|
||||
output = model_instance.run(prompts)
|
||||
output = model_instance.run(prompt_messages)
|
||||
questions = output_parser.parse(output.content)
|
||||
except LLMError:
|
||||
questions = []
|
||||
@ -145,19 +89,21 @@ class LLMGenerator:
|
||||
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
|
||||
output_parser = RuleConfigGeneratorOutputParser()
|
||||
|
||||
prompt = OutLinePromptTemplate(
|
||||
template=output_parser.get_format_instructions(),
|
||||
input_variables=["audiences", "hoping_to_solve"],
|
||||
partial_variables={
|
||||
"variable": '{variable}',
|
||||
"lanA": '{lanA}',
|
||||
"lanB": '{lanB}',
|
||||
"topic": '{topic}'
|
||||
},
|
||||
validate_template=False
|
||||
prompt_template = PromptTemplateParser(
|
||||
template=output_parser.get_format_instructions()
|
||||
)
|
||||
|
||||
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
|
||||
prompt = prompt_template.format(
|
||||
inputs={
|
||||
"audiences": audiences,
|
||||
"hoping_to_solve": hoping_to_solve,
|
||||
"variable": "{{variable}}",
|
||||
"lanA": "{{lanA}}",
|
||||
"lanB": "{{lanB}}",
|
||||
"topic": "{{topic}}"
|
||||
},
|
||||
remove_template_variables=False
|
||||
)
|
||||
|
||||
model_instance = ModelFactory.get_text_generation_model(
|
||||
tenant_id=tenant_id,
|
||||
@ -167,10 +113,10 @@ class LLMGenerator:
|
||||
)
|
||||
)
|
||||
|
||||
prompts = [PromptMessage(content=_input.to_string())]
|
||||
prompt_messages = [PromptMessage(content=prompt)]
|
||||
|
||||
try:
|
||||
output = model_instance.run(prompts)
|
||||
output = model_instance.run(prompt_messages)
|
||||
rule_config = output_parser.parse(output.content)
|
||||
except LLMError as e:
|
||||
raise e
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import openai
|
||||
|
||||
@ -16,19 +17,20 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
|
||||
length = 2000
|
||||
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
|
||||
|
||||
max_text_chunks = 32
|
||||
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
|
||||
if len(text_chunks) == 0:
|
||||
return True
|
||||
|
||||
for text_chunk in chunks:
|
||||
try:
|
||||
moderation_result = openai.Moderation.create(input=text_chunk,
|
||||
api_key=hosted_model_providers.openai.api_key)
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
text_chunk = random.choice(text_chunks)
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result['flagged'] is True:
|
||||
return False
|
||||
try:
|
||||
moderation_result = openai.Moderation.create(input=text_chunk,
|
||||
api_key=hosted_model_providers.openai.api_key)
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result['flagged'] is True:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
858
api/core/index/vector_index/milvus.py
Normal file
858
api/core/index/vector_index/milvus.py
Normal file
@ -0,0 +1,858 @@
|
||||
"""Wrapper around the Milvus vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Union, Sequence
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MILVUS_CONNECTION = {
|
||||
"host": "localhost",
|
||||
"port": "19530",
|
||||
"user": "",
|
||||
"password": "",
|
||||
"secure": False,
|
||||
}
|
||||
|
||||
|
||||
class Milvus(VectorStore):
|
||||
"""Initialize wrapper around the milvus vector database.
|
||||
|
||||
In order to use this you need to have `pymilvus` installed and a
|
||||
running Milvus
|
||||
|
||||
See the following documentation for how to run a Milvus instance:
|
||||
https://milvus.io/docs/install_standalone-docker.md
|
||||
|
||||
If looking for a hosted Milvus, take a look at this documentation:
|
||||
https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
|
||||
this project,
|
||||
|
||||
IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
|
||||
|
||||
Args:
|
||||
embedding_function (Embeddings): Function used to embed the text.
|
||||
collection_name (str): Which Milvus collection to use. Defaults to
|
||||
"LangChainCollection".
|
||||
connection_args (Optional[dict[str, any]]): The connection args used for
|
||||
this class comes in the form of a dict.
|
||||
consistency_level (str): The consistency level to use for a collection.
|
||||
Defaults to "Session".
|
||||
index_params (Optional[dict]): Which index params to use. Defaults to
|
||||
HNSW/AUTOINDEX depending on service.
|
||||
search_params (Optional[dict]): Which search params to use. Defaults to
|
||||
default of index.
|
||||
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
|
||||
to False.
|
||||
|
||||
The connection args used for this class comes in the form of a dict,
|
||||
here are a few of the options:
|
||||
address (str): The actual address of Milvus
|
||||
instance. Example address: "localhost:19530"
|
||||
uri (str): The uri of Milvus instance. Example uri:
|
||||
"http://randomwebsite:19530",
|
||||
"tcp:foobarsite:19530",
|
||||
"https://ok.s3.south.com:19530".
|
||||
host (str): The host of Milvus instance. Default at "localhost",
|
||||
PyMilvus will fill in the default host if only port is provided.
|
||||
port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
|
||||
will fill in the default port if only host is provided.
|
||||
user (str): Use which user to connect to Milvus instance. If user and
|
||||
password are provided, we will add related header in every RPC call.
|
||||
password (str): Required when user is provided. The password
|
||||
corresponding to the user.
|
||||
secure (bool): Default is false. If set to true, tls will be enabled.
|
||||
client_key_path (str): If use tls two-way authentication, need to
|
||||
write the client.key path.
|
||||
client_pem_path (str): If use tls two-way authentication, need to
|
||||
write the client.pem path.
|
||||
ca_pem_path (str): If use tls two-way authentication, need to write
|
||||
the ca.pem path.
|
||||
server_pem_path (str): If use tls one-way authentication, need to
|
||||
write the server.pem path.
|
||||
server_name (str): If use tls, need to write the common name.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import Milvus
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
embedding = OpenAIEmbeddings()
|
||||
# Connect to a milvus instance on localhost
|
||||
milvus_store = Milvus(
|
||||
embedding_function = Embeddings,
|
||||
collection_name = "LangChainCollection",
|
||||
drop_old = True,
|
||||
)
|
||||
|
||||
Raises:
|
||||
ValueError: If the pymilvus python package is not installed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_function: Embeddings,
|
||||
collection_name: str = "LangChainCollection",
|
||||
connection_args: Optional[dict[str, Any]] = None,
|
||||
consistency_level: str = "Session",
|
||||
index_params: Optional[dict] = None,
|
||||
search_params: Optional[dict] = None,
|
||||
drop_old: Optional[bool] = False,
|
||||
):
|
||||
"""Initialize the Milvus vector store."""
|
||||
try:
|
||||
from pymilvus import Collection, utility
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
|
||||
# Default search params when one is not provided.
|
||||
self.default_search_params = {
|
||||
"IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"HNSW": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
|
||||
"IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
|
||||
"ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
|
||||
"AUTOINDEX": {"metric_type": "L2", "params": {}},
|
||||
}
|
||||
|
||||
self.embedding_func = embedding_function
|
||||
self.collection_name = collection_name
|
||||
self.index_params = index_params
|
||||
self.search_params = search_params
|
||||
self.consistency_level = consistency_level
|
||||
|
||||
# In order for a collection to be compatible, pk needs to be auto'id and int
|
||||
self._primary_field = "id"
|
||||
# In order for compatibility, the text field will need to be called "text"
|
||||
self._text_field = "page_content"
|
||||
# In order for compatibility, the vector field needs to be called "vector"
|
||||
self._vector_field = "vectors"
|
||||
# In order for compatibility, the metadata field will need to be called "metadata"
|
||||
self._metadata_field = "metadata"
|
||||
self.fields: list[str] = []
|
||||
# Create the connection to the server
|
||||
if connection_args is None:
|
||||
connection_args = DEFAULT_MILVUS_CONNECTION
|
||||
self.alias = self._create_connection_alias(connection_args)
|
||||
self.col: Optional[Collection] = None
|
||||
|
||||
# Grab the existing collection if it exists
|
||||
if utility.has_collection(self.collection_name, using=self.alias):
|
||||
self.col = Collection(
|
||||
self.collection_name,
|
||||
using=self.alias,
|
||||
)
|
||||
# If need to drop old, drop it
|
||||
if drop_old and isinstance(self.col, Collection):
|
||||
self.col.drop()
|
||||
self.col = None
|
||||
|
||||
# Initialize the vector store
|
||||
self._init()
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
return self.embedding_func
|
||||
|
||||
def _create_connection_alias(self, connection_args: dict) -> str:
|
||||
"""Create the connection to the Milvus server."""
|
||||
from pymilvus import MilvusException, connections
|
||||
|
||||
# Grab the connection arguments that are used for checking existing connection
|
||||
host: str = connection_args.get("host", None)
|
||||
port: Union[str, int] = connection_args.get("port", None)
|
||||
address: str = connection_args.get("address", None)
|
||||
uri: str = connection_args.get("uri", None)
|
||||
user = connection_args.get("user", None)
|
||||
|
||||
# Order of use is host/port, uri, address
|
||||
if host is not None and port is not None:
|
||||
given_address = str(host) + ":" + str(port)
|
||||
elif uri is not None:
|
||||
given_address = uri.split("https://")[1]
|
||||
elif address is not None:
|
||||
given_address = address
|
||||
else:
|
||||
given_address = None
|
||||
logger.debug("Missing standard address type for reuse atttempt")
|
||||
|
||||
# User defaults to empty string when getting connection info
|
||||
if user is not None:
|
||||
tmp_user = user
|
||||
else:
|
||||
tmp_user = ""
|
||||
|
||||
# If a valid address was given, then check if a connection exists
|
||||
if given_address is not None:
|
||||
for con in connections.list_connections():
|
||||
addr = connections.get_connection_addr(con[0])
|
||||
if (
|
||||
con[1]
|
||||
and ("address" in addr)
|
||||
and (addr["address"] == given_address)
|
||||
and ("user" in addr)
|
||||
and (addr["user"] == tmp_user)
|
||||
):
|
||||
logger.debug("Using previous connection: %s", con[0])
|
||||
return con[0]
|
||||
|
||||
# Generate a new connection if one doesn't exist
|
||||
alias = uuid4().hex
|
||||
try:
|
||||
connections.connect(alias=alias, **connection_args)
|
||||
logger.debug("Created new connection using: %s", alias)
|
||||
return alias
|
||||
except MilvusException as e:
|
||||
logger.error("Failed to create new connection using: %s", alias)
|
||||
raise e
|
||||
|
||||
def _init(
|
||||
self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
|
||||
) -> None:
|
||||
if embeddings is not None:
|
||||
self._create_collection(embeddings, metadatas)
|
||||
self._extract_fields()
|
||||
self._create_index()
|
||||
self._create_search_params()
|
||||
self._load()
|
||||
|
||||
def _create_collection(
|
||||
self, embeddings: list, metadatas: Optional[list[dict]] = None
|
||||
) -> None:
|
||||
from pymilvus import (
|
||||
Collection,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
FieldSchema,
|
||||
MilvusException,
|
||||
)
|
||||
from pymilvus.orm.types import infer_dtype_bydata
|
||||
|
||||
# Determine embedding dim
|
||||
dim = len(embeddings[0])
|
||||
fields = []
|
||||
# Determine metadata schema
|
||||
# if metadatas:
|
||||
# # Create FieldSchema for each entry in metadata.
|
||||
# for key, value in metadatas[0].items():
|
||||
# # Infer the corresponding datatype of the metadata
|
||||
# dtype = infer_dtype_bydata(value)
|
||||
# # Datatype isn't compatible
|
||||
# if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
|
||||
# logger.error(
|
||||
# "Failure to create collection, unrecognized dtype for key: %s",
|
||||
# key,
|
||||
# )
|
||||
# raise ValueError(f"Unrecognized datatype for {key}.")
|
||||
# # Dataype is a string/varchar equivalent
|
||||
# elif dtype == DataType.VARCHAR:
|
||||
# fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
|
||||
# else:
|
||||
# fields.append(FieldSchema(key, dtype))
|
||||
if metadatas:
|
||||
fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535))
|
||||
|
||||
# Create the text field
|
||||
fields.append(
|
||||
FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
|
||||
)
|
||||
# Create the primary key field
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
self._primary_field, DataType.INT64, is_primary=True, auto_id=True
|
||||
)
|
||||
)
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(
|
||||
FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
|
||||
)
|
||||
|
||||
# Create the schema for the collection
|
||||
schema = CollectionSchema(fields)
|
||||
|
||||
# Create the collection
|
||||
try:
|
||||
self.col = Collection(
|
||||
name=self.collection_name,
|
||||
schema=schema,
|
||||
consistency_level=self.consistency_level,
|
||||
using=self.alias,
|
||||
)
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to create collection: %s error: %s", self.collection_name, e
|
||||
)
|
||||
raise e
|
||||
|
||||
def _extract_fields(self) -> None:
|
||||
"""Grab the existing fields from the Collection"""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection):
|
||||
schema = self.col.schema
|
||||
for x in schema.fields:
|
||||
self.fields.append(x.name)
|
||||
# Since primary field is auto-id, no need to track it
|
||||
self.fields.remove(self._primary_field)
|
||||
|
||||
def _get_index(self) -> Optional[dict[str, Any]]:
|
||||
"""Return the vector index information if it exists"""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection):
|
||||
for x in self.col.indexes:
|
||||
if x.field_name == self._vector_field:
|
||||
return x.to_dict()
|
||||
return None
|
||||
|
||||
def _create_index(self) -> None:
|
||||
"""Create a index on the collection"""
|
||||
from pymilvus import Collection, MilvusException
|
||||
|
||||
if isinstance(self.col, Collection) and self._get_index() is None:
|
||||
try:
|
||||
# If no index params, use a default HNSW based one
|
||||
if self.index_params is None:
|
||||
self.index_params = {
|
||||
"metric_type": "IP",
|
||||
"index_type": "HNSW",
|
||||
"params": {"M": 8, "efConstruction": 64},
|
||||
}
|
||||
|
||||
try:
|
||||
self.col.create_index(
|
||||
self._vector_field,
|
||||
index_params=self.index_params,
|
||||
using=self.alias,
|
||||
)
|
||||
|
||||
# If default did not work, most likely on Zilliz Cloud
|
||||
except MilvusException:
|
||||
# Use AUTOINDEX based index
|
||||
self.index_params = {
|
||||
"metric_type": "L2",
|
||||
"index_type": "AUTOINDEX",
|
||||
"params": {},
|
||||
}
|
||||
self.col.create_index(
|
||||
self._vector_field,
|
||||
index_params=self.index_params,
|
||||
using=self.alias,
|
||||
)
|
||||
logger.debug(
|
||||
"Successfully created an index on collection: %s",
|
||||
self.collection_name,
|
||||
)
|
||||
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to create an index on collection: %s", self.collection_name
|
||||
)
|
||||
raise e
|
||||
|
||||
def _create_search_params(self) -> None:
|
||||
"""Generate search params based on the current index type"""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection) and self.search_params is None:
|
||||
index = self._get_index()
|
||||
if index is not None:
|
||||
index_type: str = index["index_param"]["index_type"]
|
||||
metric_type: str = index["index_param"]["metric_type"]
|
||||
self.search_params = self.default_search_params[index_type]
|
||||
self.search_params["metric_type"] = metric_type
|
||||
|
||||
def _load(self) -> None:
|
||||
"""Load the collection if available."""
|
||||
from pymilvus import Collection
|
||||
|
||||
if isinstance(self.col, Collection) and self._get_index() is not None:
|
||||
self.col.load()
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
timeout: Optional[int] = None,
|
||||
batch_size: int = 1000,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Insert text data into Milvus.
|
||||
|
||||
Inserting data when the collection has not be made yet will result
|
||||
in creating a new Collection. The data of the first entity decides
|
||||
the schema of the new collection, the dim is extracted from the first
|
||||
embedding and the columns are decided by the first metadata dict.
|
||||
Metada keys will need to be present for all inserted values. At
|
||||
the moment there is no None equivalent in Milvus.
|
||||
|
||||
Args:
|
||||
texts (Iterable[str]): The texts to embed, it is assumed
|
||||
that they all fit in memory.
|
||||
metadatas (Optional[List[dict]]): Metadata dicts attached to each of
|
||||
the texts. Defaults to None.
|
||||
timeout (Optional[int]): Timeout for each batch insert. Defaults
|
||||
to None.
|
||||
batch_size (int, optional): Batch size to use for insertion.
|
||||
Defaults to 1000.
|
||||
|
||||
Raises:
|
||||
MilvusException: Failure to add texts
|
||||
|
||||
Returns:
|
||||
List[str]: The resulting keys for each inserted element.
|
||||
"""
|
||||
from pymilvus import Collection, MilvusException
|
||||
|
||||
texts = list(texts)
|
||||
|
||||
try:
|
||||
embeddings = self.embedding_func.embed_documents(texts)
|
||||
except NotImplementedError:
|
||||
embeddings = [self.embedding_func.embed_query(x) for x in texts]
|
||||
|
||||
if len(embeddings) == 0:
|
||||
logger.debug("Nothing to insert, skipping.")
|
||||
return []
|
||||
|
||||
# If the collection hasn't been initialized yet, perform all steps to do so
|
||||
if not isinstance(self.col, Collection):
|
||||
self._init(embeddings, metadatas)
|
||||
|
||||
# Dict to hold all insert columns
|
||||
insert_dict: dict[str, list] = {
|
||||
self._text_field: texts,
|
||||
self._vector_field: embeddings,
|
||||
}
|
||||
|
||||
# Collect the metadata into the insert dict.
|
||||
# if metadatas is not None:
|
||||
# for d in metadatas:
|
||||
# for key, value in d.items():
|
||||
# if key in self.fields:
|
||||
# insert_dict.setdefault(key, []).append(value)
|
||||
if metadatas is not None:
|
||||
for d in metadatas:
|
||||
insert_dict.setdefault(self._metadata_field, []).append(d)
|
||||
|
||||
# Total insert count
|
||||
vectors: list = insert_dict[self._vector_field]
|
||||
total_count = len(vectors)
|
||||
|
||||
pks: list[str] = []
|
||||
|
||||
assert isinstance(self.col, Collection)
|
||||
for i in range(0, total_count, batch_size):
|
||||
# Grab end index
|
||||
end = min(i + batch_size, total_count)
|
||||
# Convert dict to list of lists batch for insertion
|
||||
insert_list = [insert_dict[x][i:end] for x in self.fields]
|
||||
# Insert into the collection.
|
||||
try:
|
||||
res: Collection
|
||||
res = self.col.insert(insert_list, timeout=timeout, **kwargs)
|
||||
pks.extend(res.primary_keys)
|
||||
except MilvusException as e:
|
||||
logger.error(
|
||||
"Failed to insert batch starting at entity: %s/%s", i, total_count
|
||||
)
|
||||
raise e
|
||||
return pks
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Perform a similarity search against the query string.
|
||||
|
||||
Args:
|
||||
query (str): The text to search.
|
||||
k (int, optional): How many results to return. Defaults to 4.
|
||||
param (dict, optional): The search params for the index type.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
res = self.similarity_search_with_score(
|
||||
query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in res]
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Perform a similarity search against the query string.
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The embedding vector to search.
|
||||
k (int, optional): How many results to return. Defaults to 4.
|
||||
param (dict, optional): The search params for the index type.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
res = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in res]
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Perform a search on a query string and return results with score.
|
||||
|
||||
For more information about the search parameters, take a look at the pymilvus
|
||||
documentation found here:
|
||||
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
||||
|
||||
Args:
|
||||
query (str): The text being searched.
|
||||
k (int, optional): The amount of results to return. Defaults to 4.
|
||||
param (dict): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[float], List[Tuple[Document, any, any]]:
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
# Embed the query text.
|
||||
embedding = self.embedding_func.embed_query(query)
|
||||
|
||||
res = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
|
||||
)
|
||||
return res
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and relevance scores in the range [0, 1].
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
|
||||
Args:
|
||||
query: input text
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include:
|
||||
score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs
|
||||
|
||||
Returns:
|
||||
List of Tuples of (doc, similarity_score)
|
||||
"""
|
||||
return self.similarity_search_with_score(query, k, **kwargs)
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Perform a search on a query string and return results with score.
|
||||
|
||||
For more information about the search parameters, take a look at the pymilvus
|
||||
documentation found here:
|
||||
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The embedding vector being searched.
|
||||
k (int, optional): The amount of results to return. Defaults to 4.
|
||||
param (dict): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: Result doc and score.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
if param is None:
|
||||
param = self.search_params
|
||||
|
||||
# Determine result metadata fields.
|
||||
output_fields = self.fields[:]
|
||||
output_fields.remove(self._vector_field)
|
||||
|
||||
# Perform the search.
|
||||
res = self.col.search(
|
||||
data=[embedding],
|
||||
anns_field=self._vector_field,
|
||||
param=param,
|
||||
limit=k,
|
||||
expr=expr,
|
||||
output_fields=output_fields,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
# Organize results.
|
||||
ret = []
|
||||
for result in res[0]:
|
||||
meta = {x: result.entity.get(x) for x in output_fields}
|
||||
doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata'))
|
||||
pair = (doc, result.score)
|
||||
ret.append(pair)
|
||||
|
||||
return ret
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Perform a search and return results that are reordered by MMR.
|
||||
|
||||
Args:
|
||||
query (str): The text being searched.
|
||||
k (int, optional): How many results to give. Defaults to 4.
|
||||
fetch_k (int, optional): Total results to select k from.
|
||||
Defaults to 20.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5
|
||||
param (dict, optional): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
embedding = self.embedding_func.embed_query(query)
|
||||
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding=embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
param=param,
|
||||
expr=expr,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: list[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Perform a search and return results that are reordered by MMR.
|
||||
|
||||
Args:
|
||||
embedding (str): The embedding vector being searched.
|
||||
k (int, optional): How many results to give. Defaults to 4.
|
||||
fetch_k (int, optional): Total results to select k from.
|
||||
Defaults to 20.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5
|
||||
param (dict, optional): The search params for the specified index.
|
||||
Defaults to None.
|
||||
expr (str, optional): Filtering expression. Defaults to None.
|
||||
timeout (int, optional): How long to wait before timeout error.
|
||||
Defaults to None.
|
||||
kwargs: Collection.search() keyword arguments.
|
||||
|
||||
Returns:
|
||||
List[Document]: Document results for search.
|
||||
"""
|
||||
if self.col is None:
|
||||
logger.debug("No existing collection to search.")
|
||||
return []
|
||||
|
||||
if param is None:
|
||||
param = self.search_params
|
||||
|
||||
# Determine result metadata fields.
|
||||
output_fields = self.fields[:]
|
||||
output_fields.remove(self._vector_field)
|
||||
|
||||
# Perform the search.
|
||||
res = self.col.search(
|
||||
data=[embedding],
|
||||
anns_field=self._vector_field,
|
||||
param=param,
|
||||
limit=fetch_k,
|
||||
expr=expr,
|
||||
output_fields=output_fields,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
# Organize results.
|
||||
ids = []
|
||||
documents = []
|
||||
scores = []
|
||||
for result in res[0]:
|
||||
meta = {x: result.entity.get(x) for x in output_fields}
|
||||
doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
|
||||
documents.append(doc)
|
||||
scores.append(result.score)
|
||||
ids.append(result.id)
|
||||
|
||||
vectors = self.col.query(
|
||||
expr=f"{self._primary_field} in {ids}",
|
||||
output_fields=[self._primary_field, self._vector_field],
|
||||
timeout=timeout,
|
||||
)
|
||||
# Reorganize the results from query to match search order.
|
||||
vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
|
||||
|
||||
ordered_result_embeddings = [vectors[x] for x in ids]
|
||||
|
||||
# Get the new order of results.
|
||||
new_ordering = maximal_marginal_relevance(
|
||||
np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
|
||||
)
|
||||
|
||||
# Reorder the values and return.
|
||||
ret = []
|
||||
for x in new_ordering:
|
||||
# Function can return -1 index
|
||||
if x == -1:
|
||||
break
|
||||
else:
|
||||
ret.append(documents[x])
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
collection_name: str = "LangChainCollection",
|
||||
connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
|
||||
consistency_level: str = "Session",
|
||||
index_params: Optional[dict] = None,
|
||||
search_params: Optional[dict] = None,
|
||||
drop_old: bool = False,
|
||||
batch_size: int = 100,
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Milvus:
|
||||
"""Create a Milvus collection, indexes it with HNSW, and insert data.
|
||||
|
||||
Args:
|
||||
texts (List[str]): Text data.
|
||||
embedding (Embeddings): Embedding function.
|
||||
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
|
||||
Defaults to None.
|
||||
collection_name (str, optional): Collection name to use. Defaults to
|
||||
"LangChainCollection".
|
||||
connection_args (dict[str, Any], optional): Connection args to use. Defaults
|
||||
to DEFAULT_MILVUS_CONNECTION.
|
||||
consistency_level (str, optional): Which consistency level to use. Defaults
|
||||
to "Session".
|
||||
index_params (Optional[dict], optional): Which index_params to use. Defaults
|
||||
to None.
|
||||
search_params (Optional[dict], optional): Which search params to use.
|
||||
Defaults to None.
|
||||
drop_old (Optional[bool], optional): Whether to drop the collection with
|
||||
that name if it exists. Defaults to False.
|
||||
batch_size:
|
||||
How many vectors upload per-request.
|
||||
Default: 100
|
||||
ids: Optional[Sequence[str]] = None,
|
||||
|
||||
Returns:
|
||||
Milvus: Milvus Vector Store
|
||||
"""
|
||||
vector_db = cls(
|
||||
embedding_function=embedding,
|
||||
collection_name=collection_name,
|
||||
connection_args=connection_args,
|
||||
consistency_level=consistency_level,
|
||||
index_params=index_params,
|
||||
search_params=search_params,
|
||||
drop_old=drop_old,
|
||||
**kwargs,
|
||||
)
|
||||
vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
|
||||
return vector_db
|
||||
@ -9,30 +9,44 @@ from core.index.base import BaseIndex
|
||||
from core.index.vector_index.base import BaseVectorIndex
|
||||
from core.vector_store.milvus_vector_store import MilvusVectorStore
|
||||
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
|
||||
from models.dataset import Dataset
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetCollectionBinding
|
||||
|
||||
|
||||
class MilvusConfig(BaseModel):
|
||||
endpoint: str
|
||||
host: str
|
||||
port: int
|
||||
user: str
|
||||
password: str
|
||||
secure: bool = False
|
||||
batch_size: int = 100
|
||||
|
||||
@root_validator()
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values['endpoint']:
|
||||
raise ValueError("config MILVUS_ENDPOINT is required")
|
||||
if not values['host']:
|
||||
raise ValueError("config MILVUS_HOST is required")
|
||||
if not values['port']:
|
||||
raise ValueError("config MILVUS_PORT is required")
|
||||
if not values['user']:
|
||||
raise ValueError("config MILVUS_USER is required")
|
||||
if not values['password']:
|
||||
raise ValueError("config MILVUS_PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_milvus_params(self):
|
||||
return {
|
||||
'host': self.host,
|
||||
'port': self.port,
|
||||
'user': self.user,
|
||||
'password': self.password,
|
||||
'secure': self.secure
|
||||
}
|
||||
|
||||
|
||||
class MilvusVectorIndex(BaseVectorIndex):
|
||||
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
|
||||
super().__init__(dataset, embeddings)
|
||||
self._client = self._init_client(config)
|
||||
self._client_config = config
|
||||
|
||||
def get_type(self) -> str:
|
||||
return 'milvus'
|
||||
@ -49,7 +63,6 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
dataset_id = dataset.id
|
||||
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
|
||||
|
||||
|
||||
def to_index_struct(self) -> dict:
|
||||
return {
|
||||
"type": self.get_type(),
|
||||
@ -58,26 +71,29 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
index_params = {
|
||||
'metric_type': 'IP',
|
||||
'index_type': "HNSW",
|
||||
'params': {"M": 8, "efConstruction": 64}
|
||||
}
|
||||
self._vector_store = MilvusVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
connection_args=self._client_config.to_milvus_params(),
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
|
||||
uuids = self._get_uuids(texts)
|
||||
self._vector_store = WeaviateVectorStore.from_documents(
|
||||
self._vector_store = MilvusVectorStore.from_documents(
|
||||
texts,
|
||||
self._embeddings,
|
||||
client=self._client,
|
||||
index_name=collection_name,
|
||||
uuids=uuids,
|
||||
by_text=False
|
||||
collection_name=collection_name,
|
||||
ids=uuids,
|
||||
content_payload_key='page_content'
|
||||
)
|
||||
|
||||
return self
|
||||
@ -86,42 +102,53 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
"""Only for created index."""
|
||||
if self._vector_store:
|
||||
return self._vector_store
|
||||
|
||||
attributes = ['doc_id', 'dataset_id', 'document_id']
|
||||
if self._is_origin():
|
||||
attributes = ['doc_id']
|
||||
|
||||
return WeaviateVectorStore(
|
||||
client=self._client,
|
||||
index_name=self.get_index_name(self.dataset),
|
||||
text_key='text',
|
||||
embedding=self._embeddings,
|
||||
attributes=attributes,
|
||||
by_text=False
|
||||
return MilvusVectorStore(
|
||||
collection_name=self.get_index_name(self.dataset),
|
||||
embedding_function=self._embeddings,
|
||||
connection_args=self._client_config.to_milvus_params()
|
||||
)
|
||||
|
||||
def _get_vector_store_class(self) -> type:
|
||||
return MilvusVectorStore
|
||||
|
||||
def delete_by_document_id(self, document_id: str):
|
||||
if self._is_origin():
|
||||
self.recreate_dataset(self.dataset)
|
||||
return
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
ids = vector_store.get_ids_by_document_id(document_id)
|
||||
if ids:
|
||||
vector_store.del_texts({
|
||||
'filter': f'id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_ids(self, doc_ids: list[str]) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
ids = vector_store.get_ids_by_doc_ids(doc_ids)
|
||||
vector_store.del_texts({
|
||||
'filter': f' id in {ids}'
|
||||
})
|
||||
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.del_texts({
|
||||
"operator": "Equal",
|
||||
"path": ["document_id"],
|
||||
"valueText": document_id
|
||||
})
|
||||
vector_store.delete()
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
|
||||
if not class_prefix.endswith('_Node'):
|
||||
# original class_prefix
|
||||
return True
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
return False
|
||||
from qdrant_client.http import models
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self.dataset.id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
@ -47,6 +47,20 @@ class VectorIndex:
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
elif vector_type == "milvus":
|
||||
from core.index.vector_index.milvus_vector_index import MilvusVectorIndex, MilvusConfig
|
||||
|
||||
return MilvusVectorIndex(
|
||||
dataset=dataset,
|
||||
config=MilvusConfig(
|
||||
host=config.get('MILVUS_HOST'),
|
||||
port=config.get('MILVUS_PORT'),
|
||||
user=config.get('MILVUS_USER'),
|
||||
password=config.get('MILVUS_PASSWORD'),
|
||||
secure=config.get('MILVUS_SECURE'),
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ from flask import current_app, Flask
|
||||
from flask_login import current_user
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from core.data_loader.loader.notion import NotionLoader
|
||||
@ -79,6 +80,8 @@ class IndexingRunner:
|
||||
dataset_document.error = str(e.description)
|
||||
dataset_document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
except ObjectDeletedError:
|
||||
logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
|
||||
except Exception as e:
|
||||
logging.exception("consume document failed")
|
||||
dataset_document.indexing_status = 'error'
|
||||
@ -86,22 +89,6 @@ class IndexingRunner:
|
||||
dataset_document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
def format_split_text(self, text):
|
||||
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
|
||||
matches = re.findall(regex, text, re.MULTILINE)
|
||||
|
||||
result = []
|
||||
for match in matches:
|
||||
q = match[0]
|
||||
a = match[1]
|
||||
if q and a:
|
||||
result.append({
|
||||
"question": q,
|
||||
"answer": re.sub(r"\n\s*", "\n", a.strip())
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def run_in_splitting_status(self, dataset_document: DatasetDocument):
|
||||
"""Run the indexing process when the index_status is splitting."""
|
||||
try:
|
||||
@ -276,13 +263,14 @@ class IndexingRunner:
|
||||
)
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||
doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
@ -372,13 +360,14 @@ class IndexingRunner:
|
||||
)
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||
doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
|
||||
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
@ -582,7 +571,6 @@ class IndexingRunner:
|
||||
|
||||
all_qa_documents.extend(format_documents)
|
||||
|
||||
|
||||
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
|
||||
processing_rule: DatasetProcessRule) -> List[Document]:
|
||||
"""
|
||||
@ -643,21 +631,16 @@ class IndexingRunner:
|
||||
return text
|
||||
|
||||
def format_split_text(self, text):
|
||||
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)" # 匹配Q和A的正则表达式
|
||||
matches = re.findall(regex, text, re.MULTILINE) # 获取所有匹配到的结果
|
||||
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
|
||||
matches = re.findall(regex, text, re.MULTILINE)
|
||||
|
||||
result = [] # 存储最终的结果
|
||||
for match in matches:
|
||||
q = match[0]
|
||||
a = match[1]
|
||||
if q and a:
|
||||
# 如果Q和A都存在,就将其添加到结果中
|
||||
result.append({
|
||||
"question": q,
|
||||
"answer": re.sub(r"\n\s*", "\n", a.strip())
|
||||
})
|
||||
|
||||
return result
|
||||
return [
|
||||
{
|
||||
"question": q,
|
||||
"answer": re.sub(r"\n\s*", "\n", a.strip())
|
||||
}
|
||||
for q, a in matches if q and a
|
||||
]
|
||||
|
||||
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
|
||||
"""
|
||||
@ -734,6 +717,9 @@ class IndexingRunner:
|
||||
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
|
||||
if count > 0:
|
||||
raise DocumentIsPausedException()
|
||||
document = DatasetDocument.query.filter_by(id=document_id).first()
|
||||
if not document:
|
||||
raise DocumentIsDeletedPausedException()
|
||||
|
||||
update_params = {
|
||||
DatasetDocument.indexing_status: after_indexing_status
|
||||
@ -781,3 +767,7 @@ class IndexingRunner:
|
||||
|
||||
class DocumentIsPausedException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIsDeletedPausedException(Exception):
|
||||
pass
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any, List, Dict
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import get_buffer_string, BaseMessage
|
||||
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from extensions.ext_database import db
|
||||
@ -21,6 +22,8 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
@property
|
||||
def buffer(self) -> List[BaseMessage]:
|
||||
"""String buffer of memory."""
|
||||
app_model = self.conversation.app
|
||||
|
||||
# fetch limited messages desc, and return reversed
|
||||
messages = db.session.query(Message).filter(
|
||||
Message.conversation_id == self.conversation.id,
|
||||
@ -28,10 +31,25 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
|
||||
).order_by(Message.created_at.desc()).limit(self.message_limit).all()
|
||||
|
||||
messages = list(reversed(messages))
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)
|
||||
|
||||
chat_messages: List[PromptMessage] = []
|
||||
for message in messages:
|
||||
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
|
||||
files = message.message_files
|
||||
if files:
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
files, message.app_model_config
|
||||
)
|
||||
|
||||
prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
|
||||
chat_messages.append(PromptMessage(
|
||||
content=message.query,
|
||||
type=MessageType.USER,
|
||||
files=prompt_message_files
|
||||
))
|
||||
else:
|
||||
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
|
||||
|
||||
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
|
||||
|
||||
if not chat_messages:
|
||||
|
||||
@ -211,6 +211,9 @@ class ModelProviderFactory:
|
||||
Provider.quota_type == ProviderQuotaType.TRIAL.value
|
||||
).first()
|
||||
|
||||
if provider.quota_limit == 0:
|
||||
return None
|
||||
|
||||
return provider
|
||||
|
||||
no_system_provider = True
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
|
||||
|
||||
|
||||
class XinferenceEmbedding(BaseEmbedding):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import enum
|
||||
from typing import Any, cast, Union, List, Dict
|
||||
|
||||
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
|
||||
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -9,26 +10,67 @@ class LLMRunResult(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
source: list = None
|
||||
function_call: dict = None
|
||||
|
||||
|
||||
class MessageType(enum.Enum):
|
||||
HUMAN = 'human'
|
||||
USER = 'user'
|
||||
ASSISTANT = 'assistant'
|
||||
SYSTEM = 'system'
|
||||
|
||||
|
||||
class PromptMessageFileType(enum.Enum):
|
||||
IMAGE = 'image'
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in PromptMessageFileType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
|
||||
class PromptMessageFile(BaseModel):
|
||||
type: PromptMessageFileType
|
||||
data: Any
|
||||
|
||||
|
||||
class ImagePromptMessageFile(PromptMessageFile):
|
||||
class DETAIL(enum.Enum):
|
||||
LOW = 'low'
|
||||
HIGH = 'high'
|
||||
|
||||
type: PromptMessageFileType = PromptMessageFileType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
|
||||
|
||||
class PromptMessage(BaseModel):
|
||||
type: MessageType = MessageType.HUMAN
|
||||
type: MessageType = MessageType.USER
|
||||
content: str = ''
|
||||
files: list[PromptMessageFile] = []
|
||||
function_call: dict = None
|
||||
|
||||
|
||||
class LCHumanMessageWithFiles(HumanMessage):
|
||||
# content: Union[str, List[Union[str, Dict]]]
|
||||
content: str
|
||||
files: list[PromptMessageFile]
|
||||
|
||||
|
||||
def to_lc_messages(messages: list[PromptMessage]):
|
||||
lc_messages = []
|
||||
for message in messages:
|
||||
if message.type == MessageType.HUMAN:
|
||||
lc_messages.append(HumanMessage(content=message.content))
|
||||
if message.type == MessageType.USER:
|
||||
if not message.files:
|
||||
lc_messages.append(HumanMessage(content=message.content))
|
||||
else:
|
||||
lc_messages.append(LCHumanMessageWithFiles(content=message.content, files=message.files))
|
||||
elif message.type == MessageType.ASSISTANT:
|
||||
lc_messages.append(AIMessage(content=message.content))
|
||||
additional_kwargs = {}
|
||||
if message.function_call:
|
||||
additional_kwargs['function_call'] = message.function_call
|
||||
lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
|
||||
elif message.type == MessageType.SYSTEM:
|
||||
lc_messages.append(SystemMessage(content=message.content))
|
||||
|
||||
@ -39,11 +81,28 @@ def to_prompt_messages(messages: list[BaseMessage]):
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message, HumanMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
|
||||
if isinstance(message, LCHumanMessageWithFiles):
|
||||
prompt_messages.append(PromptMessage(
|
||||
content=message.content,
|
||||
type=MessageType.USER,
|
||||
files=message.files
|
||||
))
|
||||
else:
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
|
||||
elif isinstance(message, AIMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
|
||||
message_kwargs = {
|
||||
'content': message.content,
|
||||
'type': MessageType.ASSISTANT
|
||||
}
|
||||
|
||||
if 'function_call' in message.additional_kwargs:
|
||||
message_kwargs['function_call'] = message.additional_kwargs['function_call']
|
||||
|
||||
prompt_messages.append(PromptMessage(**message_kwargs))
|
||||
elif isinstance(message, SystemMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
|
||||
elif isinstance(message, FunctionMessage):
|
||||
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
|
||||
return prompt_messages
|
||||
|
||||
|
||||
|
||||
@ -8,3 +8,4 @@ class ProviderQuotaUnit(Enum):
|
||||
|
||||
class ModelFeature(Enum):
|
||||
AGENT_THOUGHT = 'agent_thought'
|
||||
VISION = 'vision'
|
||||
|
||||
@ -19,6 +19,13 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar
|
||||
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
|
||||
|
||||
|
||||
FUNCTION_CALL_MODELS = [
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-35-turbo',
|
||||
'gpt-35-turbo-16k'
|
||||
]
|
||||
|
||||
class AzureOpenAIModel(BaseLLM):
|
||||
def __init__(self, model_provider: BaseModelProvider,
|
||||
name: str,
|
||||
@ -81,7 +88,20 @@ class AzureOpenAIModel(BaseLLM):
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
generate_kwargs = {
|
||||
'stop': stop,
|
||||
'callbacks': callbacks
|
||||
}
|
||||
|
||||
if isinstance(prompts, str):
|
||||
generate_kwargs['prompts'] = [prompts]
|
||||
else:
|
||||
generate_kwargs['messages'] = [prompts]
|
||||
|
||||
if 'functions' in kwargs:
|
||||
generate_kwargs['functions'] = kwargs['functions']
|
||||
|
||||
return self._client.generate(**generate_kwargs)
|
||||
|
||||
@property
|
||||
def base_model_name(self) -> str:
|
||||
@ -144,3 +164,7 @@ class AzureOpenAIModel(BaseLLM):
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def support_function_call(self):
|
||||
return self.base_model_name in FUNCTION_CALL_MODELS
|
||||
|
||||
@ -37,12 +37,6 @@ class BaichuanModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
|
||||
if mode == 'completion':
|
||||
return 'baichuan_completion'
|
||||
else:
|
||||
return 'baichuan_chat'
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
@ -1,27 +1,18 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Any, Union, Tuple
|
||||
from typing import List, Optional, Any, Union
|
||||
import decimal
|
||||
import logging
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
|
||||
from langchain.schema import LLMResult, BaseMessage, ChatGeneration
|
||||
|
||||
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
|
||||
from core.helper import moderation
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_lc_messages
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.prompt.prompt_builder import PromptBuilder
|
||||
from core.prompt.prompt_template import JinjaPromptTemplate
|
||||
from core.third_party.langchain.llms.fake import FakeLLM
|
||||
import logging
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -157,8 +148,11 @@ class BaseLLM(BaseProviderModel):
|
||||
except Exception as ex:
|
||||
raise self.handle_exceptions(ex)
|
||||
|
||||
function_call = None
|
||||
if isinstance(result.generations[0][0], ChatGeneration):
|
||||
completion_content = result.generations[0][0].message.content
|
||||
if 'function_call' in result.generations[0][0].message.additional_kwargs:
|
||||
function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
|
||||
else:
|
||||
completion_content = result.generations[0][0].text
|
||||
|
||||
@ -191,7 +185,8 @@ class BaseLLM(BaseProviderModel):
|
||||
return LLMRunResult(
|
||||
content=completion_content,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens
|
||||
completion_tokens=completion_tokens,
|
||||
function_call=function_call
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@ -227,7 +222,7 @@ class BaseLLM(BaseProviderModel):
|
||||
:param message_type:
|
||||
:return:
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
@ -245,7 +240,7 @@ class BaseLLM(BaseProviderModel):
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.0001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
|
||||
unit_price = self.price_config['prompt']
|
||||
else:
|
||||
unit_price = self.price_config['completion']
|
||||
@ -260,7 +255,7 @@ class BaseLLM(BaseProviderModel):
|
||||
:param message_type:
|
||||
:return: decimal.Decimal('0.000001')
|
||||
"""
|
||||
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
|
||||
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
|
||||
price_unit = self.price_config['unit']
|
||||
else:
|
||||
price_unit = self.price_config['unit']
|
||||
@ -315,121 +310,12 @@ class BaseLLM(BaseProviderModel):
|
||||
def support_streaming(self):
|
||||
return False
|
||||
|
||||
def get_prompt(self, mode: str,
|
||||
pre_prompt: str, inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
memory: Optional[BaseChatMemory]) -> \
|
||||
Tuple[List[PromptMessage], Optional[List[str]]]:
|
||||
prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
|
||||
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
|
||||
return [PromptMessage(content=prompt)], stops
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
|
||||
if mode == 'completion':
|
||||
return 'common_completion'
|
||||
else:
|
||||
return 'common_chat'
|
||||
|
||||
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
|
||||
context_prompt_content = ''
|
||||
if context and 'context_prompt' in prompt_rules:
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
|
||||
context_prompt_content = prompt_template.format(
|
||||
context=context
|
||||
)
|
||||
|
||||
pre_prompt_content = ''
|
||||
if pre_prompt:
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
|
||||
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
|
||||
pre_prompt_content = prompt_template.format(
|
||||
**prompt_inputs
|
||||
)
|
||||
|
||||
prompt = ''
|
||||
for order in prompt_rules['system_prompt_orders']:
|
||||
if order == 'context_prompt':
|
||||
prompt += context_prompt_content
|
||||
elif order == 'pre_prompt':
|
||||
prompt += pre_prompt_content
|
||||
|
||||
query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
|
||||
|
||||
if memory and 'histories_prompt' in prompt_rules:
|
||||
# append chat histories
|
||||
tmp_human_message = PromptBuilder.to_human_message(
|
||||
prompt_content=prompt + query_prompt,
|
||||
inputs={
|
||||
'query': query
|
||||
}
|
||||
)
|
||||
|
||||
if self.model_rules.max_tokens.max:
|
||||
curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
|
||||
max_tokens = self.model_kwargs.max_tokens
|
||||
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
else:
|
||||
rest_tokens = 2000
|
||||
|
||||
memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
|
||||
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
|
||||
|
||||
histories = self._get_history_messages_from_memory(memory, rest_tokens)
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
|
||||
histories_prompt_content = prompt_template.format(
|
||||
histories=histories
|
||||
)
|
||||
|
||||
prompt = ''
|
||||
for order in prompt_rules['system_prompt_orders']:
|
||||
if order == 'context_prompt':
|
||||
prompt += context_prompt_content
|
||||
elif order == 'pre_prompt':
|
||||
prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
|
||||
elif order == 'histories_prompt':
|
||||
prompt += histories_prompt_content
|
||||
|
||||
prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
|
||||
query_prompt_content = prompt_template.format(
|
||||
query=query
|
||||
)
|
||||
|
||||
prompt += query_prompt_content
|
||||
|
||||
prompt = re.sub(r'<\|.*?\|>', '', prompt)
|
||||
|
||||
stops = prompt_rules.get('stops')
|
||||
if stops is not None and len(stops) == 0:
|
||||
stops = None
|
||||
|
||||
return prompt, stops
|
||||
|
||||
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
|
||||
# Get the absolute path of the subdirectory
|
||||
prompt_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
|
||||
'prompt/generate_prompts')
|
||||
|
||||
json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
|
||||
# Open the JSON file and read its content
|
||||
with open(json_file_path, 'r') as json_file:
|
||||
return json.load(json_file)
|
||||
|
||||
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
|
||||
max_token_limit: int) -> str:
|
||||
"""Get memory messages."""
|
||||
memory.max_token_limit = max_token_limit
|
||||
memory_key = memory.memory_variables[0]
|
||||
external_context = memory.load_memory_variables({})
|
||||
return external_context[memory_key]
|
||||
@property
|
||||
def support_function_call(self):
|
||||
return False
|
||||
|
||||
def _get_prompt_from_messages(self, messages: List[PromptMessage],
|
||||
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
|
||||
model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
|
||||
if not model_mode:
|
||||
model_mode = self.model_mode
|
||||
|
||||
@ -442,16 +328,7 @@ class BaseLLM(BaseProviderModel):
|
||||
if len(messages) == 0:
|
||||
return []
|
||||
|
||||
chat_messages = []
|
||||
for message in messages:
|
||||
if message.type == MessageType.HUMAN:
|
||||
chat_messages.append(HumanMessage(content=message.content))
|
||||
elif message.type == MessageType.ASSISTANT:
|
||||
chat_messages.append(AIMessage(content=message.content))
|
||||
elif message.type == MessageType.SYSTEM:
|
||||
chat_messages.append(SystemMessage(content=message.content))
|
||||
|
||||
return chat_messages
|
||||
return to_lc_messages(messages)
|
||||
|
||||
def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
|
||||
"""
|
||||
|
||||
@ -66,15 +66,6 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.get_num_tokens(prompts)
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
|
||||
if 'baichuan' in self.name.lower():
|
||||
if mode == 'completion':
|
||||
return 'baichuan_completion'
|
||||
else:
|
||||
return 'baichuan_chat'
|
||||
else:
|
||||
return super().prompt_file_name(mode)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
self.client.model_kwargs = provider_model_kwargs
|
||||
|
||||
@ -1,26 +1,23 @@
|
||||
import decimal
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import Minimax
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
|
||||
|
||||
|
||||
class MinimaxModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return Minimax(
|
||||
return MinimaxChatLLM(
|
||||
model=self.name,
|
||||
model_kwargs={
|
||||
'stream': False
|
||||
},
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
**provider_model_kwargs
|
||||
@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts), 0)
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
|
||||
return LLMBadRequestError(f"Minimax: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import decimal
|
||||
import logging
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import openai
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
from openai import api_requestor
|
||||
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
|
||||
@ -23,21 +21,36 @@ COMPLETION_MODELS = [
|
||||
]
|
||||
|
||||
CHAT_MODELS = [
|
||||
'gpt-4-1106-preview', # 128,000 tokens
|
||||
'gpt-4-vision-preview', # 128,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo-1106', # 16,384 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k', # 16,384 tokens
|
||||
]
|
||||
|
||||
MODEL_MAX_TOKENS = {
|
||||
'gpt-4-1106-preview': 128000,
|
||||
'gpt-4-vision-preview': 128000,
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo-1106': 16384,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-instruct': 8192,
|
||||
'gpt-3.5-turbo-instruct': 4097,
|
||||
'gpt-3.5-turbo-16k': 16384,
|
||||
'text-davinci-003': 4097,
|
||||
}
|
||||
|
||||
FUNCTION_CALL_MODELS = [
|
||||
'gpt-4-1106-preview',
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-3.5-turbo-1106',
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-16k'
|
||||
]
|
||||
|
||||
|
||||
class OpenAIModel(BaseLLM):
|
||||
def __init__(self, model_provider: BaseModelProvider,
|
||||
@ -50,7 +63,6 @@ class OpenAIModel(BaseLLM):
|
||||
else:
|
||||
self.model_mode = ModelMode.CHAT
|
||||
|
||||
# TODO load price config from configs(db)
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
@ -100,13 +112,27 @@ class OpenAIModel(BaseLLM):
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
if self.name == 'gpt-4' \
|
||||
if self.name.startswith('gpt-4') \
|
||||
and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
|
||||
and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
|
||||
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
|
||||
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
generate_kwargs = {
|
||||
'stop': stop,
|
||||
'callbacks': callbacks
|
||||
}
|
||||
|
||||
if isinstance(prompts, str):
|
||||
generate_kwargs['prompts'] = [prompts]
|
||||
else:
|
||||
generate_kwargs['messages'] = [prompts]
|
||||
|
||||
if 'functions' in kwargs:
|
||||
generate_kwargs['functions'] = kwargs['functions']
|
||||
|
||||
return self._client.generate(**generate_kwargs)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
@ -161,6 +187,10 @@ class OpenAIModel(BaseLLM):
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def support_function_call(self):
|
||||
return self.name in FUNCTION_CALL_MODELS
|
||||
|
||||
# def is_model_valid_or_raise(self):
|
||||
# """
|
||||
# check is a valid model.
|
||||
|
||||
@ -49,15 +49,6 @@ class OpenLLMModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
|
||||
if 'baichuan' in self.name.lower():
|
||||
if mode == 'completion':
|
||||
return 'baichuan_completion'
|
||||
else:
|
||||
return 'baichuan_chat'
|
||||
else:
|
||||
return super().prompt_file_name(mode)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
pass
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@ class TongyiModel(BaseLLM):
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
del provider_model_kwargs['max_tokens']
|
||||
return EnhanceTongyi(
|
||||
model_name=self.name,
|
||||
max_retries=1,
|
||||
@ -58,7 +57,6 @@ class TongyiModel(BaseLLM):
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
del provider_model_kwargs['max_tokens']
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
|
||||
@ -6,17 +6,16 @@ from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.wenxin import Wenxin
|
||||
|
||||
|
||||
class WenxinModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
# TODO load price_config from configs(db)
|
||||
return Wenxin(
|
||||
model=self.name,
|
||||
streaming=self.streaming,
|
||||
@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
|
||||
|
||||
if 'functions' in kwargs:
|
||||
generate_kwargs['functions'] = kwargs['functions']
|
||||
|
||||
return self._client.generate(**generate_kwargs)
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts), 0)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Wenxin: {str(ex)}")
|
||||
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@ -59,15 +59,6 @@ class XinferenceModel(BaseLLM):
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
|
||||
if 'baichuan' in self.name.lower():
|
||||
if mode == 'completion':
|
||||
return 'baichuan_completion'
|
||||
else:
|
||||
return 'baichuan_chat'
|
||||
else:
|
||||
return super().prompt_file_name(mode)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
pass
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ class ZhipuAIModel(BaseLLM):
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ZhipuAIChatLLM(
|
||||
model=self.name,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
|
||||
@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
|
||||
from core.model_providers.models.entity.provider import ModelFeature
|
||||
from core.model_providers.models.llm.anthropic_model import AnthropicModel
|
||||
from core.model_providers.models.llm.base import ModelType
|
||||
@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'claude-instant-1',
|
||||
'name': 'claude-instant-1',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
},
|
||||
{
|
||||
'id': 'claude-2',
|
||||
'name': 'claude-2',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.CHAT.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
@ -167,7 +172,7 @@ class AnthropicProvider(BaseModelProvider):
|
||||
|
||||
def should_deduct_quota(self):
|
||||
if hosted_model_providers.anthropic and \
|
||||
hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > 0:
|
||||
hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > -1:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@ -12,7 +12,7 @@ from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
|
||||
AZURE_OPENAI_API_VERSION
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
|
||||
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
|
||||
from core.model_providers.models.entity.provider import ModelFeature
|
||||
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
}
|
||||
|
||||
credentials = json.loads(provider_model.encrypted_config)
|
||||
|
||||
if provider_model.model_type == ModelType.TEXT_GENERATION.value:
|
||||
model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
|
||||
|
||||
if credentials['base_model_name'] in [
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
|
||||
return model_list
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
if model_name == 'text-davinci-003':
|
||||
return ModelMode.COMPLETION.value
|
||||
else:
|
||||
return ModelMode.CHAT.value
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
models = [
|
||||
{
|
||||
'id': 'gpt-3.5-turbo',
|
||||
'name': 'gpt-3.5-turbo',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'gpt-3.5-turbo-16k',
|
||||
'name': 'gpt-3.5-turbo-16k',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'gpt-4',
|
||||
'name': 'gpt-4',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'gpt-4-32k',
|
||||
'name': 'gpt-4-32k',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'text-davinci-003',
|
||||
'name': 'text-davinci-003',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
}
|
||||
]
|
||||
|
||||
@ -314,7 +329,7 @@ class AzureOpenAIProvider(BaseModelProvider):
|
||||
|
||||
def should_deduct_quota(self):
|
||||
if hosted_model_providers.azure_openai \
|
||||
and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > 0:
|
||||
and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > -1:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.baichuan_model import BaichuanModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
|
||||
@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider):
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'baichuan'
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.CHAT.value
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'baichuan2-53b',
|
||||
'name': 'Baichuan2-53B',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
}
|
||||
]
|
||||
else:
|
||||
|
||||
@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
|
||||
ProviderModel.is_valid == True
|
||||
).order_by(ProviderModel.created_at.asc()).all()
|
||||
|
||||
return [{
|
||||
'id': provider_model.model_name,
|
||||
'name': provider_model.model_name
|
||||
} for provider_model in provider_models]
|
||||
provider_model_list = []
|
||||
for provider_model in provider_models:
|
||||
provider_model_dict = {
|
||||
'id': provider_model.model_name,
|
||||
'name': provider_model.model_name
|
||||
}
|
||||
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
|
||||
|
||||
provider_model_list.append(provider_model_dict)
|
||||
|
||||
return provider_model_list
|
||||
|
||||
@abstractmethod
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
"""
|
||||
get text generation model mode.
|
||||
|
||||
:param model_name:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_model_class(self, model_type: ModelType) -> Type:
|
||||
"""
|
||||
|
||||
@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
|
||||
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from models.provider import ProviderType
|
||||
@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
|
||||
{
|
||||
'id': 'chatglm2-6b',
|
||||
'name': 'ChatGLM2-6B',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
},
|
||||
{
|
||||
'id': 'chatglm-6b',
|
||||
'name': 'ChatGLM-6B',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.COMPLETION.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user