Compare commits

...

51 Commits

Author SHA1 Message Date
5bffa1d918 feat: bump version to 0.3.24 (#1262) 2023-09-28 18:32:06 +08:00
c9b0fe47bf Fix/notion sync (#1258) 2023-09-28 14:39:13 +08:00
bcd744b6b7 fix: doc (#1256) 2023-09-28 11:26:04 +08:00
5e511e01bf Fix/dataset api key delete (#1255)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-28 10:41:41 +08:00
52291c645e fix: dataset footer styles (#1254) 2023-09-28 10:06:52 +08:00
a31466d34e fix: db session not commit before long llm call running (#1251) 2023-09-27 21:40:26 +08:00
d38eac959b fix: wenxin model name invalid when llm call (#1248) 2023-09-27 16:29:13 +08:00
9dbb8acd4b Feat/dataset support api service (#1240)
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
2023-09-27 16:06:49 +08:00
46154c6705 Feat/dataset service api (#1245)
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-09-27 16:06:32 +08:00
54ff03c35d fix: dataset query error. (#1244) 2023-09-27 15:24:54 +08:00
18c710c906 feat: support binding context var (#1227)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-09-27 14:53:22 +08:00
59236b789f Fix: dataset list refresh (#1216) 2023-09-27 10:31:46 +08:00
fd3d43cae1 Fix: debounce of dataset creation (#1237) 2023-09-27 10:31:27 +08:00
8eae643911 Fix App logs page modal show different model icon. (#1224) 2023-09-27 08:54:52 +08:00
fd9413874a fix: FATAL: role "root" does not exist. (#1233) 2023-09-26 10:20:00 +08:00
227f9fb77d Feat/api jwt (#1212) 2023-09-25 12:49:16 +08:00
c40ee7e629 feat: batch run support retry errors and decrease rate limit times (#1215) 2023-09-25 10:20:50 +08:00
841e967d48 Fix: add loading for dataset creation (#1214) 2023-09-24 01:35:20 -05:00
9df0dcedae fix: dataset eslint error (#1221) 2023-09-22 22:38:33 +08:00
724e053732 Fix/qdrant data issue (#1203)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-22 14:21:26 +08:00
e409895c02 Feat/huggingface embedding support (#1211)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-09-22 13:59:02 +08:00
32d9b6181c fix: transaction not commit during long LLM calls (#1213) 2023-09-22 12:43:06 +08:00
2b018fade2 fix: transaction hangs due to message commit block during long LLM calls (#1206) 2023-09-21 11:22:10 +08:00
e65f9cb17a Complete type defined. (#1200) 2023-09-19 23:27:06 -05:00
1367f34398 fix: provider spark free quota text (#1201) 2023-09-20 11:46:25 +08:00
e47f6b879a add help wanted issue template (#1199) 2023-09-19 20:02:41 -05:00
5809edd74b feat: bump version to 0.3.23 (#1198) 2023-09-20 00:14:36 +08:00
05bfa11915 build: update devDependencies (#1125) 2023-09-19 13:31:48 +08:00
435f804c6f fix: gpt-3.5-turbo-instruct context size to 8192 (#1196) 2023-09-19 02:10:22 +08:00
ae3f1ac0a9 feat: support gpt-3.5-turbo-instruct model (#1195) 2023-09-19 02:05:04 +08:00
269a465fc4 Feat/improve vector database logic (#1193)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-18 18:15:41 +08:00
60e0bbd713 Feat/provider add zhipuai (#1192)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-09-18 18:02:05 +08:00
827c97f0d3 feat: add zhipuai (#1188) 2023-09-18 17:32:31 +08:00
c8bd76cd66 fix: inference embedding validate (#1187) 2023-09-16 03:09:36 +08:00
ec5f585df4 1111 wrong embedding model displayed in datasets (#1186) 2023-09-15 07:54:45 -05:00
1de48f33ca feat(web): service request return generics type (#1157) 2023-09-15 07:54:20 -05:00
6b41a9593e fix: text error (#1184) 2023-09-15 14:15:28 +08:00
82267083e8 fix: model param description error (#1183) 2023-09-15 11:36:01 +08:00
c385961d33 chore: Optimization model parameter description (#1181) 2023-09-15 11:14:14 +08:00
20bab6edec Restore the application template (#1174)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-09-14 08:28:32 -05:00
67bed54f32 Mermaid front end rendering (#1166)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-09-14 14:09:23 +08:00
562a571281 fix: Improved fallback solution for avatar image loading failure (#1172) 2023-09-14 13:31:35 +08:00
fc68c81791 fix: correct invite url (#1173) 2023-09-14 12:07:34 +08:00
5d9070bc60 Feat/add blocking mode resource return (#1171)
Co-authored-by: jyong <jyong@dify.ai>
2023-09-13 18:53:35 +08:00
b11fb0dfd1 fix LocalAI is missing in lang/en (#1169) 2023-09-13 10:08:33 +08:00
d1c5c5f160 add video to cn readme (#1165) 2023-09-12 08:30:12 -05:00
0b1d1440aa Update README.md (#1164) 2023-09-12 07:48:35 -05:00
0c420d64b3 chore: hover conversation show option button (#1160) 2023-09-12 16:35:13 +08:00
f9082104ed feat: add hosted moderation (#1158) 2023-09-12 10:26:12 +08:00
983834cd52 feat: spark check (#1134) 2023-09-11 17:31:03 +08:00
96d10c8b39 feat: spark free quota verify (#1152) 2023-09-11 17:30:54 +08:00
272 changed files with 9341 additions and 2926 deletions

.github/ISSUE_TEMPLATE/help_wanted.yml (new file, +11 lines)
View File

@@ -0,0 +1,11 @@
name: "🤝 Help Wanted"
description: "Request help from the community"
labels:
- help-wanted
body:
- type: textarea
attributes:
label: Provide a description of the help you need
placeholder: Briefly describe what you need help with.
validations:
required: true

View File

@@ -16,6 +16,10 @@ Out-of-the-box web sites supporting form mode and chat conversation mode
A single API encompassing plugin capabilities, context enhancement, and more, saving you backend coding effort
Visual data analysis, log review, and annotation for applications
https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
## Highlighted Features
**1. LLM support:** Choose capabilities based on different models when building your Dify AI apps. Dify is compatible with Langchain, meaning it will gradually support various LLMs. Currently supported:

View File

@@ -17,7 +17,7 @@
- A single API encompassing plugin capabilities, context enhancement, and more, saving you backend coding effort
- Visual data analysis, log review, and annotation for applications
https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e191da4d0f6
## Core Capabilities
1. **Model support:** You can build your AI apps on Dify with capabilities based on different models. Dify is compatible with Langchain, which means we will gradually support a variety of LLMs. Currently supported model providers:

View File

@@ -50,24 +50,6 @@ S3_REGION=your-region
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Cookie configuration
COOKIE_HTTPONLY=true
COOKIE_SAMESITE=None
COOKIE_SECURE=true
# Session configuration
SESSION_PERMANENT=true
SESSION_USE_SIGNER=true
## support redis, sqlalchemy
SESSION_TYPE=redis
# session redis configuration
SESSION_REDIS_HOST=localhost
SESSION_REDIS_PORT=6379
SESSION_REDIS_PASSWORD=difyai123456
SESSION_REDIS_DB=2
# Vector database configuration, support: weaviate, qdrant
VECTOR_STORE=weaviate

View File

@@ -1,8 +1,7 @@
# -*- coding:utf-8 -*-
import os
from datetime import datetime, timedelta
from werkzeug.exceptions import Forbidden
from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
@@ -12,12 +11,11 @@ import logging
import json
import threading
from flask import Flask, request, Response, session
import flask_login
from flask import Flask, request, Response
from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe
from extensions.ext_database import db
from extensions.ext_login import login_manager
@@ -27,12 +25,10 @@ from models import model, account, dataset, web, task, source, tool
from events import event_handlers
# DO NOT REMOVE ABOVE
import core
from config import Config, CloudEditionConfig
from commands import register_commands
from models.account import TenantAccountJoin, AccountStatus
from models.model import Account, EndUser, App
from services.account_service import TenantService
from services.account_service import AccountService
from libs.passport import PassportService
import warnings
warnings.simplefilter("ignore", ResourceWarning)
@@ -85,81 +81,33 @@ def initialize_extensions(app):
ext_redis.init_app(app)
ext_storage.init_app(app)
ext_celery.init_app(app)
ext_session.init_app(app)
ext_login.init_app(app)
ext_mail.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)
def _create_tenant_for_account(account):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role='owner')
account.current_tenant = tenant
return tenant
# Flask-Login configuration
@login_manager.user_loader
def load_user(user_id):
"""Load user based on the user_id."""
@login_manager.request_loader
def load_user_from_request(request_from_flask_login):
"""Load user based on the request."""
if request.blueprint == 'console':
# Check if the user_id contains a dot, indicating the old format
if '.' in user_id:
tenant_id, account_id = user_id.split('.')
else:
account_id = user_id
auth_header = request.headers.get('Authorization', '')
if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
decoded = PassportService().verify(auth_token)
user_id = decoded.get('user_id')
account = db.session.query(Account).filter(Account.id == account_id).first()
if account:
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
raise Forbidden('Account is banned or closed.')
workspace_id = session.get('workspace_id')
if workspace_id:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == workspace_id
).first()
if not tenant_account_join:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
else:
account.current_tenant_id = workspace_id
else:
tenant_account_join = db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id).first()
if tenant_account_join:
account.current_tenant_id = tenant_account_join.tenant_id
else:
_create_tenant_for_account(account)
session['workspace_id'] = account.current_tenant_id
current_time = datetime.utcnow()
# update last_active_at when last_active_at is more than 10 minutes ago
if current_time - account.last_active_at > timedelta(minutes=10):
account.last_active_at = current_time
db.session.commit()
# Log in the user with the updated user_id
flask_login.login_user(account, remember=True)
return account
return AccountService.load_user(user_id)
else:
return None
@login_manager.unauthorized_handler
def unauthorized_handler():
"""Handle unauthorized requests."""
@@ -216,6 +164,7 @@ if app.config['TESTING']:
@app.after_request
def after_request(response):
"""Add Version headers to the response."""
response.set_cookie('remember_token', '', expires=0)
response.headers.add('X-Version', app.config['CURRENT_VERSION'])
response.headers.add('X-Env', app.config['DEPLOY_ENV'])
return response
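The net effect of the app.py refactor above: console requests no longer ride on a server-side session; each request carries a bearer JWT that load_user_from_request verifies through PassportService. As a rough sketch of what such a passport helper might look like (assuming PyJWT and the app's SECRET_KEY; the real libs/passport.py may differ in claims, lifetime, and error handling):

# Hypothetical sketch of a JWT passport helper, assuming PyJWT is installed.
# The actual libs/passport.py may use different claims, lifetimes, and errors.
from datetime import datetime, timedelta

import jwt  # PyJWT


class SketchPassportService:
    def __init__(self, secret_key: str):
        self.secret_key = secret_key

    def issue(self, user_id: str) -> str:
        payload = {
            'user_id': user_id,
            'exp': datetime.utcnow() + timedelta(days=30),  # assumed token lifetime
        }
        return jwt.encode(payload, self.secret_key, algorithm='HS256')

    def verify(self, token: str) -> dict:
        # Raises jwt.InvalidTokenError on a bad signature or an expired token.
        return jwt.decode(token, self.secret_key, algorithms=['HS256'])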

View File

@@ -3,11 +3,13 @@ import json
import math
import random
import string
import threading
import time
import uuid
import click
from tqdm import tqdm
from flask import current_app
from flask import current_app, Flask
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound
@@ -23,7 +25,7 @@ from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
import secrets
import base64
@@ -239,7 +241,13 @@ def clean_unused_dataset_indexes():
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
vector_index.delete()
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
if dataset.collection_binding_id:
vector_index.delete_by_group_id(dataset.id)
else:
vector_index.delete()
kw_index.delete()
# update document
update_params = {
@@ -346,7 +354,8 @@ def create_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -364,7 +373,8 @@ def create_qdrant_indexes():
index.create_qdrant_dataset(dataset)
index_struct = {
"type": 'qdrant',
"vector_store": {"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
"vector_store": {
"class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
@@ -373,7 +383,8 @@ def create_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
@@ -414,7 +425,8 @@ def update_qdrant_indexes():
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
@@ -435,11 +447,104 @@ def update_qdrant_indexes():
click.echo('passed.')
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
continue
click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))
@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
click.echo(click.style('Start normalization collections.', fg='green'))
normalization_count = []
page = 1
while True:
try:
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
except NotFound:
break
datasets_result = datasets.items
page += 1
for i in range(0, len(datasets_result), 5):
threads = []
sub_datasets = datasets_result[i:i + 5]
for dataset in sub_datasets:
document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
'flask_app': current_app._get_current_object(),
'dataset': dataset,
'normalization_count': normalization_count
})
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
with flask_app.app_context():
try:
click.echo('restore dataset index: {}'.format(dataset.id))
try:
embedding_model = ModelFactory.get_embedding_model(
tenant_id=dataset.tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except Exception:
provider = Provider(
id='provider_id',
tenant_id=dataset.tenant_id,
provider_name='openai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
DatasetCollectionBinding.model_name == embedding_model.name). \
order_by(DatasetCollectionBinding.created_at). \
first()
if not dataset_collection_binding:
dataset_collection_binding = DatasetCollectionBinding(
provider_name=embedding_model.model_provider.provider_name,
model_name=embedding_model.name,
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
)
db.session.add(dataset_collection_binding)
db.session.commit()
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
index = QdrantVectorIndex(
dataset=dataset,
config=QdrantConfig(
endpoint=current_app.config.get('QDRANT_URL'),
api_key=current_app.config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
),
embeddings=embeddings
)
if index:
# index.delete_by_group_id(dataset.id)
index.restore_dataset_in_one(dataset, dataset_collection_binding)
else:
click.echo('passed.')
normalization_count.append(1)
except Exception as e:
click.echo(
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
fg='red'))
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):
@@ -473,7 +578,7 @@ def update_app_model_configs(batch_size):
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.count()
if total_records == 0:
click.secho("No data to migrate.", fg='green')
return
@@ -485,14 +590,14 @@ def update_app_model_configs(batch_size):
offset = i * batch_size
limit = min(batch_size, total_records - offset)
click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
data_batch = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.order_by(App.created_at) \
.offset(offset).limit(limit).all()
if not data_batch:
click.secho("No more data to migrate.", fg='green')
break
@@ -512,7 +617,7 @@ def update_app_model_configs(batch_size):
app_data = db.session.query(App) \
.filter(App.id == data.app_id) \
.one()
account_data = db.session.query(Account) \
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
.filter(TenantAccountJoin.role == 'owner') \
@@ -534,13 +639,85 @@ def update_app_model_configs(batch_size):
db.session.commit()
except Exception as e:
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
fg='red')
continue
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
pbar.update(len(data_batch))
@click.command('migrate_default_input_to_dataset_query_variable')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def migrate_default_input_to_dataset_query_variable(batch_size):
click.secho("Starting...", fg='green')
total_records = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.filter(AppModelConfig.dataset_query_variable == None) \
.count()
if total_records == 0:
click.secho("No data to migrate.", fg='green')
return
num_batches = (total_records + batch_size - 1) // batch_size
with tqdm(total=total_records, desc="Migrating Data") as pbar:
for i in range(num_batches):
offset = i * batch_size
limit = min(batch_size, total_records - offset)
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
data_batch = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \
.filter(AppModelConfig.dataset_query_variable == None) \
.order_by(App.created_at) \
.offset(offset).limit(limit).all()
if not data_batch:
click.secho("No more data to migrate.", fg='green')
break
try:
click.secho(f"Migrating {len(data_batch)} records...", fg='green')
for data in data_batch:
config = AppModelConfig.to_dict(data)
tools = config["agent_mode"]["tools"]
dataset_exists = "dataset" in str(tools)
if not dataset_exists:
continue
user_input_form = config.get("user_input_form", [])
for form in user_input_form:
paragraph = form.get('paragraph')
if paragraph \
and paragraph.get('variable') == 'query':
data.dataset_query_variable = 'query'
break
if paragraph \
and paragraph.get('variable') == 'default_input':
data.dataset_query_variable = 'default_input'
break
db.session.commit()
except Exception as e:
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
fg='red')
continue
click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
pbar.update(len(data_batch))
def register_commands(app):
app.cli.add_command(reset_password)
app.cli.add_command(reset_email)
@@ -551,4 +728,6 @@ def register_commands(app):
app.cli.add_command(clean_unused_dataset_indexes)
app.cli.add_command(create_qdrant_indexes)
app.cli.add_command(update_qdrant_indexes)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(update_app_model_configs)
app.cli.add_command(normalization_collections)
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
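A pattern worth noting in the new normalization-collections command above: datasets are fanned out to worker threads in batches of five, and each worker re-enters the Flask application context before touching the database. A minimal standalone sketch of that pattern, with a hypothetical do_work standing in for deal_dataset_vector:

# Minimal sketch of the thread-plus-app-context pattern used above.
# `do_work` is a hypothetical stand-in for deal_dataset_vector.
import threading

from flask import Flask, current_app


def do_work(flask_app: Flask, item, results: list):
    # Each worker thread needs its own app context to use db/config safely.
    with flask_app.app_context():
        results.append(item * 2)  # placeholder for real per-dataset work


def process_in_batches(items, batch_size=5):
    # Must be called inside an app context; unwrap the proxy to pass the
    # real app object into the threads.
    app = current_app._get_current_object()
    results = []
    for i in range(0, len(items), batch_size):
        threads = [
            threading.Thread(target=do_work, args=(app, item, results))
            for item in items[i:i + batch_size]
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()  # finish the batch before starting the next one
    return results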

View File

@@ -10,9 +10,6 @@ from extensions.ext_redis import redis_client
dotenv.load_dotenv()
DEFAULTS = {
'COOKIE_HTTPONLY': 'True',
'COOKIE_SECURE': 'True',
'COOKIE_SAMESITE': 'None',
'DB_USERNAME': 'postgres',
'DB_PASSWORD': '',
'DB_HOST': 'localhost',
@@ -22,10 +19,6 @@ DEFAULTS = {
'REDIS_PORT': '6379',
'REDIS_DB': '0',
'REDIS_USE_SSL': 'False',
'SESSION_REDIS_HOST': 'localhost',
'SESSION_REDIS_PORT': '6379',
'SESSION_REDIS_DB': '2',
'SESSION_REDIS_USE_SSL': 'False',
'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize',
'OAUTH_REDIRECT_INDEX_PATH': '/',
'CONSOLE_WEB_URL': 'https://cloud.dify.ai',
@@ -36,9 +29,6 @@ DEFAULTS = {
'STORAGE_TYPE': 'local',
'STORAGE_LOCAL_PATH': 'storage',
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
'SESSION_TYPE': 'sqlalchemy',
'SESSION_PERMANENT': 'True',
'SESSION_USE_SIGNER': 'True',
'DEPLOY_ENV': 'PRODUCTION',
'SQLALCHEMY_POOL_SIZE': 30,
'SQLALCHEMY_POOL_RECYCLE': 3600,
@@ -61,6 +51,8 @@ DEFAULTS = {
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1000000,
'HOSTED_ANTHROPIC_PAID_MIN_QUANTITY': 20,
'HOSTED_ANTHROPIC_PAID_MAX_QUANTITY': 100,
'HOSTED_MODERATION_ENABLED': 'False',
'HOSTED_MODERATION_PROVIDERS': '',
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
@@ -100,7 +92,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.22"
self.CURRENT_VERSION = "0.3.24"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -113,20 +105,6 @@
# Alternatively you can set it with `SECRET_KEY` environment variable.
self.SECRET_KEY = get_env('SECRET_KEY')
# cookie settings
self.REMEMBER_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
self.SESSION_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY')
self.REMEMBER_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
self.SESSION_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE')
self.REMEMBER_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
self.SESSION_COOKIE_SECURE = get_bool_env('COOKIE_SECURE')
self.PERMANENT_SESSION_LIFETIME = timedelta(days=7)
# session settings, only support sqlalchemy, redis
self.SESSION_TYPE = get_env('SESSION_TYPE')
self.SESSION_PERMANENT = get_bool_env('SESSION_PERMANENT')
self.SESSION_USE_SIGNER = get_bool_env('SESSION_USE_SIGNER')
# redis settings
self.REDIS_HOST = get_env('REDIS_HOST')
self.REDIS_PORT = get_env('REDIS_PORT')
@@ -135,14 +113,6 @@
self.REDIS_DB = get_env('REDIS_DB')
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
# session redis settings
self.SESSION_REDIS_HOST = get_env('SESSION_REDIS_HOST')
self.SESSION_REDIS_PORT = get_env('SESSION_REDIS_PORT')
self.SESSION_REDIS_USERNAME = get_env('SESSION_REDIS_USERNAME')
self.SESSION_REDIS_PASSWORD = get_env('SESSION_REDIS_PASSWORD')
self.SESSION_REDIS_DB = get_env('SESSION_REDIS_DB')
self.SESSION_REDIS_USE_SSL = get_bool_env('SESSION_REDIS_USE_SSL')
# storage settings
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
@@ -230,6 +200,9 @@
self.HOSTED_ANTHROPIC_PAID_MIN_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MIN_QUANTITY'))
self.HOSTED_ANTHROPIC_PAID_MAX_QUANTITY = int(get_env('HOSTED_ANTHROPIC_PAID_MAX_QUANTITY'))
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

View File

@@ -16,7 +16,7 @@ model_templates = {
},
'model_config': {
'provider': 'openai',
'model_id': 'text-davinci-003',
'model_id': 'gpt-3.5-turbo-instruct',
'configs': {
'prompt_template': '',
'prompt_variables': [],
@@ -30,7 +30,7 @@
},
'model': json.dumps({
"provider": "openai",
"name": "text-davinci-003",
"name": "gpt-3.5-turbo-instruct",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -104,7 +104,7 @@ demo_model_templates = {
'mode': 'completion',
'model_config': AppModelConfig(
provider='openai',
model_id='text-davinci-003',
model_id='gpt-3.5-turbo-instruct',
configs={
'prompt_template': "Please translate the following text into {{target_language}}:\n",
'prompt_variables': [
@@ -140,7 +140,7 @@ demo_model_templates = {
pre_prompt="Please translate the following text into {{target_language}}:\n",
model=json.dumps({
"provider": "openai",
"name": "text-davinci-003",
"name": "gpt-3.5-turbo-instruct",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -222,7 +222,7 @@ demo_model_templates = {
'mode': 'completion',
'model_config': AppModelConfig(
provider='openai',
model_id='text-davinci-003',
model_id='gpt-3.5-turbo-instruct',
configs={
'prompt_template': "请将以下文本翻译为{{target_language}}:\n",
'prompt_variables': [
@@ -258,7 +258,7 @@ demo_model_templates = {
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
model=json.dumps({
"provider": "openai",
"name": "text-davinci-003",
"name": "gpt-3.5-turbo-instruct",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,

View File

@@ -81,6 +81,7 @@ class BaseApiKeyListResource(Resource):
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)

View File

@@ -19,40 +19,13 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
app_detail_fields_with_site
from libs.helper import TimestampField
from extensions.ext_database import db
from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
}
app_detail_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'created_at': TimestampField
}
def _get_app(app_id, tenant_id):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
@@ -62,35 +35,6 @@ def _get_app(app_id, tenant_id):
class AppListApi(Resource):
prompt_config_fields = {
'prompt_template': fields.String,
}
model_config_partial_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
app_partial_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
'created_at': TimestampField
}
app_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
}
@setup_required
@login_required
@@ -162,7 +106,8 @@
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=model_config_dict
config=model_config_dict,
mode=args['mode']
)
app = App(
@@ -236,18 +181,6 @@
class AppTemplateApi(Resource):
template_fields = {
'name': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'mode': fields.String,
'model_config': fields.Nested(model_config_fields),
}
template_list_fields = {
'data': fields.List(fields.Nested(template_fields)),
}
@setup_required
@login_required
@@ -266,38 +199,6 @@
class AppApi(Resource):
site_fields = {
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean,
'app_base_url': fields.String,
}
app_detail_fields_with_site = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'site': fields.Nested(site_fields),
'api_base_url': fields.String,
'created_at': TimestampField
}
@setup_required
@login_required

View File

@@ -13,107 +13,14 @@ from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
conversation_message_detail_fields, conversation_with_summary_pagination_fields
from libs.helper import TimestampField, datetime_string, uuid_value
from extensions.ext_database import db
from models.model import Message, MessageAnnotation, Conversation
account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
feedback_fields = {
'rating': fields.String,
'content': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account': fields.Nested(account_fields, allow_null=True),
}
annotation_fields = {
'content': fields.String,
'account': fields.Nested(account_fields, allow_null=True),
'created_at': TimestampField
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'message': fields.Raw,
'message_tokens': fields.Integer,
'answer': fields.String,
'answer_tokens': fields.Integer,
'provider_response_latency': fields.Float,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
}
feedback_stat_fields = {
'like': fields.Integer,
'dislike': fields.Integer
}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'model': fields.Raw,
'user_input_form': fields.Raw,
'pre_prompt': fields.String,
'agent_mode': fields.Raw,
}
class CompletionConversationApi(Resource):
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]['text'] if value else ''
simple_configs_fields = {
'prompt_template': fields.String,
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
simple_message_detail_fields = {
'inputs': fields.Raw,
'query': fields.String,
'message': MessageTextField,
'answer': fields.String,
}
conversation_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String(),
'from_account_id': fields.String,
'read_at': TimestampField,
'created_at': TimestampField,
'annotation': fields.Nested(annotation_fields, allow_null=True),
'model_config': fields.Nested(simple_model_config_fields),
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
}
conversation_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
}
@setup_required
@login_required
@@ -191,21 +98,11 @@
class CompletionConversationDetailApi(Resource):
conversation_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'model_config': fields.Nested(model_config_fields),
'message': fields.Nested(message_detail_fields, attribute='first_message'),
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(conversation_detail_fields)
@marshal_with(conversation_message_detail_fields)
def get(self, app_id, conversation_id):
app_id = str(app_id)
conversation_id = str(conversation_id)
@@ -234,44 +131,11 @@
class ChatConversationApi(Resource):
simple_configs_fields = {
'prompt_template': fields.String,
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
conversation_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String,
'from_account_id': fields.String,
'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(simple_model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
conversation_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
}
@setup_required
@login_required
@account_initialization_required
@marshal_with(conversation_pagination_fields)
@marshal_with(conversation_with_summary_pagination_fields)
def get(self, app_id):
app_id = str(app_id)
@@ -356,19 +220,6 @@
class ChatConversationDetailApi(Resource):
conversation_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
@setup_required
@login_required

View File

@@ -17,6 +17,7 @@ from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.login.login import login_required
from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value, TimestampField
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
@@ -27,44 +28,6 @@ from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
from services.message_service import MessageService
account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
feedback_fields = {
'rating': fields.String,
'content': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account': fields.Nested(account_fields, allow_null=True),
}
annotation_fields = {
'content': fields.String,
'account': fields.Nested(account_fields, allow_null=True),
'created_at': TimestampField
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'message': fields.Raw,
'message_tokens': fields.Integer,
'answer': fields.String,
'answer_tokens': fields.Integer,
'provider_response_latency': fields.Float,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
}
class ChatMessageListApi(Resource):
message_infinite_scroll_pagination_fields = {

View File

@@ -31,7 +31,8 @@
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=request.json
config=request.json,
mode=app_model.mode
)
new_app_model_config = AppModelConfig(

View File

@@ -8,26 +8,11 @@ from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.app_fields import app_site_fields
from libs.helper import supported_language
from extensions.ext_database import db
from models.model import Site
app_site_fields = {
'app_id': fields.String,
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean
}
def parse_app_site_args():
parser = reqparse.RequestParser()

View File

@@ -45,15 +45,34 @@ class OAuthDataSource(Resource):
if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
oauth_provider.save_internal_access_token(internal_secret)
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
return { 'data': '' }
else:
auth_url = oauth_provider.get_authorization_url()
return redirect(auth_url)
return { 'data': auth_url }, 200
class OAuthDataSourceCallback(Resource):
def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider:
return {'error': 'Invalid provider'}, 400
if 'code' in request.args:
code = request.args.get('code')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&code={code}')
elif 'error' in request.args:
error = request.args.get('error')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error={error}')
else:
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error=Access denied')
class OAuthDataSourceBinding(Resource):
def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
@@ -69,12 +88,7 @@
f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
return {'error': 'OAuth data source process failed'}, 400
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=success')
elif 'error' in request.args:
error = request.args.get('error')
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source={error}')
else:
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_data_source=access_denied')
return {'result': 'success'}, 200
class OAuthDataSourceSync(Resource):
@@ -101,4 +115,5 @@
api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')

View File

@@ -6,7 +6,6 @@ from flask_restful import Resource, reqparse
import services
from controllers.console import api
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.setup import setup_required
from libs.helper import email
from libs.password import valid_password
@@ -37,12 +36,12 @@ class LoginApi(Resource):
except Exception:
pass
flask_login.login_user(account, remember=args['remember_me'])
AccountService.update_last_login(account, request)
# todo: return the user info
token = AccountService.get_account_jwt_token(account)
return {'result': 'success'}
return {'result': 'success', 'data': token}
class LogoutApi(Resource):
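With LoginApi now returning a JWT instead of setting a session cookie, a client logs in once and replays the token on every console request. A hedged sketch of that flow using requests; the host and endpoint paths are assumptions based on this diff, not documented guarantees:

# Sketch of the new client-side login flow. BASE and the endpoint paths
# are assumptions; adjust to your deployment.
import requests

BASE = 'http://localhost:5001'  # assumed console API host

resp = requests.post(f'{BASE}/console/api/login', json={
    'email': 'user@example.com',
    'password': 'your-password',
    'remember_me': True,
})
resp.raise_for_status()
token = resp.json()['data']  # LoginApi now returns {'result': 'success', 'data': <jwt>}

# Subsequent requests carry the bearer token instead of a session cookie.
apps = requests.get(f'{BASE}/console/api/apps',
                    headers={'Authorization': f'Bearer {token}'})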

View File

@@ -2,9 +2,8 @@ import logging
from datetime import datetime
from typing import Optional
import flask_login
import requests
from flask import request, redirect, current_app, session
from flask import request, redirect, current_app
from flask_restful import Resource
from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth
@@ -75,12 +74,11 @@ class OAuthCallback(Resource):
account.initialized_at = datetime.utcnow()
db.session.commit()
# login user
session.clear()
flask_login.login_user(account, remember=True)
AccountService.update_last_login(account, request)
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?oauth_login=success')
token = AccountService.get_account_jwt_token(account)
return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:

View File

@@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required
from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from extensions.ext_database import db
from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
from libs.helper import TimestampField
from models.dataset import Document
from models.source import DataSourceBinding
@@ -24,37 +25,6 @@ cache = TTLCache(maxsize=None, ttl=30)
class DataSourceApi(Resource):
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields)),
'total': fields.Integer
}
integrate_fields = {
'id': fields.String,
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'disabled': fields.Boolean,
'link': fields.String,
'source_info': fields.Nested(integrate_workspace_fields)
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
}
@setup_required
@login_required
@@ -131,28 +101,6 @@ class DataSourceApi(Resource):
class DataSourceNotionListApi(Resource):
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'is_bound': fields.Boolean,
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
}
integrate_notion_info_list_fields = {
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
}
@setup_required
@login_required

View File

@@ -1,6 +1,9 @@
# -*- coding:utf-8 -*-
from flask import request
import flask_restful
from flask import request, current_app
from flask_login import current_user
from controllers.console.apikey import api_key_list, api_key_fields
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
@@ -12,45 +15,16 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.indexing_runner import IndexingRunner
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType
from libs.helper import TimestampField
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from extensions.ext_database import db
from models.dataset import DocumentSegment, Document
from models.model import UploadFile
from models.model import UploadFile, ApiToken
from services.dataset_service import DatasetService, DocumentService
from services.provider_service import ProviderService
dataset_detail_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'provider': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'app_count': fields.Integer,
'document_count': fields.Integer,
'word_count': fields.Integer,
'created_by': fields.String,
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
}
dataset_query_detail_fields = {
"id": fields.String,
"content": fields.String,
"source": fields.String,
"source_app_id": fields.String,
"created_by_role": fields.String,
"created_by": fields.String,
"created_at": TimestampField
}
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
@@ -82,7 +56,8 @@ class DatasetListApi(Resource):
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
# if len(valid_model_list) == 0:
# raise ProviderNotInitializeError(
# f"No Embedding Model available. Please configure a valid provider "
@@ -157,7 +132,8 @@ class DatasetApi(Resource):
# check embedding setting
provider_service = ProviderService()
# get valid model list
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value)
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
@@ -271,7 +247,8 @@ class DatasetIndexingEstimateApi(Resource):
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
# validate args
DocumentService.estimate_args_validate(args)
@@ -320,18 +297,6 @@
class DatasetRelatedAppListApi(Resource):
app_detail_kernel_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
}
related_app_list = {
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
'total': fields.Integer,
}
@setup_required
@login_required
@@ -363,24 +328,6 @@
class DatasetIndexingStatusApi(Resource):
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
document_status_fields_list = {
'data': fields.List(fields.Nested(document_status_fields))
}
@setup_required
@login_required
@@ -400,16 +347,101 @@
DocumentSegment.status != 're_segment').count()
document.completed_segments = completed_segments
document.total_segments = total_segments
documents_status.append(marshal(document, self.document_status_fields))
documents_status.append(marshal(document, document_status_fields))
data = {
'data': documents_status
}
return data
class DatasetApiKeyApi(Resource):
max_keys = 10
token_prefix = 'dataset-'
resource_type = 'dataset'
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
keys = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
all()
return {"items": keys}
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_fields)
def post(self):
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
current_key_count = db.session.query(ApiToken). \
filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
count()
if current_key_count >= self.max_keys:
flask_restful.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code='max_keys_exceeded'
)
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token, 200
class DatasetApiDeleteApi(Resource):
resource_type = 'dataset'
@setup_required
@login_required
@account_initialization_required
def delete(self, api_key_id):
api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
if current_user.current_tenant.current_role not in ['admin', 'owner']:
raise Forbidden()
key = db.session.query(ApiToken). \
filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id). \
first()
if key is None:
flask_restful.abort(404, message='API key not found')
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit()
return {'result': 'success'}, 204
class DatasetApiBaseUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
return {
'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
else request.host_url.rstrip('/')) + '/v1'
}
api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
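Putting the new dataset endpoints together: create a dataset- key via /datasets/api-keys, read the service base URL from /datasets/api-base-info, then call the dataset service API with that key as a bearer token. A rough usage sketch; the exact /v1 routes are assumptions, so consult the service API controllers for the real surface:

# Rough sketch of consuming the new dataset service API with a 'dataset-' key.
# API_BASE and the /datasets route under /v1 are assumptions from this diff.
import requests

API_BASE = 'http://localhost:5001/v1'  # value reported by /datasets/api-base-info
API_KEY = 'dataset-...'                # key created via /datasets/api-keys

resp = requests.get(f'{API_BASE}/datasets',
                    headers={'Authorization': f'Bearer {API_KEY}'},
                    params={'page': 1, 'limit': 20})
resp.raise_for_status()
for dataset in resp.json()['data']:
    print(dataset['id'], dataset['name'])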

View File

@@ -23,6 +23,8 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, \
LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from fields.document_fields import document_with_segments_fields, document_fields, \
dataset_and_document_fields, document_status_fields
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DatasetProcessRule, Dataset
@@ -32,64 +34,6 @@ from services.dataset_service import DocumentService, DatasetService
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
dataset_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
document_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'doc_form': fields.String,
}
document_with_segments_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'completed_segments': fields.Integer,
'total_segments': fields.Integer
}
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
@ -303,11 +247,6 @@ class DatasetDocumentListApi(Resource):
class DatasetInitApi(Resource):
dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields),
'documents': fields.List(fields.Nested(document_fields)),
'batch': fields.String
}
@setup_required
@login_required
@ -504,24 +443,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
class DocumentBatchIndexingStatusApi(DocumentResource):
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
document_status_fields_list = {
'data': fields.List(fields.Nested(document_status_fields))
}
@setup_required
@login_required
@ -541,7 +462,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
documents_status.append(marshal(document, self.document_status_fields))
documents_status.append(marshal(document, document_status_fields))
data = {
'data': documents_status
}
@ -549,20 +470,6 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
class DocumentIndexingStatusApi(DocumentResource):
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
@setup_required
@login_required
@ -586,7 +493,7 @@ class DocumentIndexingStatusApi(DocumentResource):
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
return marshal(document, self.document_status_fields)
return marshal(document, document_status_fields)
class DocumentDetailApi(DocumentResource):

View File

@ -3,7 +3,7 @@ import uuid
from datetime import datetime
from flask import request
from flask_login import current_user
from flask_restful import Resource, reqparse, fields, marshal
from flask_restful import Resource, reqparse, marshal
from werkzeug.exceptions import NotFound, Forbidden
import services
@ -17,6 +17,7 @@ from core.model_providers.model_factory import ModelFactory
from core.login.login import login_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields
from models.dataset import DocumentSegment
from libs.helper import TimestampField
@ -26,36 +27,6 @@ from tasks.disable_segment_from_index_task import disable_segment_from_index_tas
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
import pandas as pd
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField
}
segment_list_response = {
'data': fields.List(fields.Nested(segment_fields)),
'has_more': fields.Boolean,
'limit': fields.Integer
}
class DatasetDocumentSegmentListApi(Resource):
@setup_required

View File

@ -1,28 +1,19 @@
import datetime
import hashlib
import tempfile
import chardet
import time
import uuid
from pathlib import Path
from cachetools import TTLCache
from flask import request, current_app
from flask_login import current_user
import services
from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
UnsupportedFileTypeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.data_loader.file_extractor import FileExtractor
from extensions.ext_storage import storage
from libs.helper import TimestampField
from extensions.ext_database import db
from models.model import UploadFile
from fields.file_fields import upload_config_fields, file_fields
from services.file_service import FileService
cache = TTLCache(maxsize=None, ttl=30)
@ -31,10 +22,6 @@ PREVIEW_WORDS_LIMIT = 3000
class FileApi(Resource):
upload_config_fields = {
'file_size_limit': fields.Integer,
'batch_count_limit': fields.Integer
}
@setup_required
@login_required
@ -48,16 +35,6 @@ class FileApi(Resource):
'batch_count_limit': batch_count_limit
}, 200
file_fields = {
'id': fields.String,
'name': fields.String,
'size': fields.Integer,
'extension': fields.String,
'mime_type': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
@setup_required
@login_required
@account_initialization_required
@ -73,45 +50,13 @@ class FileApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
file_content = file.read()
file_size = len(file_content)
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
if file_size > file_size_limit:
message = "({file_size} > {file_size_limit})"
raise FileTooLargeError(message)
extension = file.filename.split('.')[-1]
if extension.lower() not in ALLOWED_EXTENSIONS:
try:
upload_file = FileService.upload_file(file)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
# save file to storage
storage.save(file_key, file_content)
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=file.filename,
size=file_size,
extension=extension,
mime_type=file.mimetype,
created_by=current_user.id,
created_at=datetime.datetime.utcnow(),
used=False,
hash=hashlib.sha3_256(file_content).hexdigest()
)
db.session.add(upload_file)
db.session.commit()
return upload_file, 201
@ -121,26 +66,7 @@ class FilePreviewApi(Resource):
@account_initialization_required
def get(self, file_id):
file_id = str(file_id)
key = file_id + request.path
cached_response = cache.get(key)
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
return cached_response['response']
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
if not upload_file:
raise NotFound("File not found")
# extract text from file
extension = upload_file.extension
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
text = FileExtractor.load(upload_file, return_text=True)
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
text = FileService.get_file_preview(file_id)
return {'content': text}

View File

@ -2,7 +2,7 @@ import logging
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, marshal, fields
from flask_restful import Resource, reqparse, marshal
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
import services
@ -14,48 +14,10 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
LLMBadRequestError
from libs.helper import TimestampField
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
document_fields = {
'id': fields.String,
'data_source_type': fields.String,
'name': fields.String,
'doc_type': fields.String,
}
segment_fields = {
'id': fields.String,
'position': fields.Integer,
'document_id': fields.String,
'content': fields.String,
'answer': fields.String,
'word_count': fields.Integer,
'tokens': fields.Integer,
'keywords': fields.List(fields.String),
'index_node_id': fields.String,
'index_node_hash': fields.String,
'hit_count': fields.Integer,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'status': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'indexing_at': TimestampField,
'completed_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'document': fields.Nested(document_fields),
}
hit_testing_record_fields = {
'segment': fields.Nested(segment_fields),
'score': fields.Float,
'tsne_position': fields.Raw
}
class HitTestingApi(Resource):

View File

@ -7,26 +7,12 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class ConversationListApi(InstalledAppResource):
@ -76,7 +62,7 @@ class ConversationApi(InstalledAppResource):
class ConversationRenameApi(InstalledAppResource):
@marshal_with(conversation_fields)
@marshal_with(simple_conversation_fields)
def post(self, installed_app, c_id):
app_model = installed_app.app
if app_model.mode != 'chat':

View File

@ -11,32 +11,11 @@ from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.helper import TimestampField
from models.model import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
app_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String
}
installed_app_fields = {
'id': fields.String,
'app': fields.Nested(app_fields),
'app_owner_tenant_id': fields.String,
'is_pinned': fields.Boolean,
'last_used_at': TimestampField,
'editable': fields.Boolean,
'uninstallable': fields.Boolean,
}
installed_app_list_fields = {
'installed_apps': fields.List(fields.Nested(installed_app_fields))
}
class InstalledAppsListApi(Resource):
@login_required

View File

@ -17,6 +17,7 @@ from controllers.console.explore.error import NotCompletionAppError, AppSuggeste
from controllers.console.explore.wraps import InstalledAppResource
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs.helper import uuid_value, TimestampField
from services.completion_service import CompletionService
from services.errors.app import MoreLikeThisDisabledError
@ -26,45 +27,6 @@ from services.message_service import MessageService
class MessageListApi(InstalledAppResource):
feedback_fields = {
'rating': fields.String
}
retriever_resource_fields = {
'id': fields.String,
'message_id': fields.String,
'position': fields.Integer,
'dataset_id': fields.String,
'dataset_name': fields.String,
'document_id': fields.String,
'document_name': fields.String,
'data_source_type': fields.String,
'segment_id': fields.String,
'score': fields.Float,
'hit_count': fields.Integer,
'word_count': fields.Integer,
'segment_position': fields.Integer,
'index_node_hash': fields.String,
'content': fields.String,
'created_at': TimestampField
}
message_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField
}
message_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(message_fields))
}
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, installed_app):

View File

@ -1,7 +1,6 @@
# -*- coding:utf-8 -*-
from functools import wraps
import flask_login
from flask import request, current_app
from flask_restful import Resource, reqparse
@ -58,9 +57,6 @@ class SetupApi(Resource):
)
setup()
# Login
flask_login.login_user(account)
AccountService.update_last_login(account, request)
return {'result': 'success'}, 201

View File

@ -33,7 +33,6 @@ class UniversalChatApi(UniversalChatResource):
args = parser.parse_args()
app_model_config = app_model.app_model_config
app_model_config
# update app model config
args['model_config'] = app_model_config.to_dict()

View File

@ -6,31 +6,17 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.universal_chat.wraps import UniversalChatResource
from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \
conversation_with_model_config_fields
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField,
'model_config': fields.Raw,
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class UniversalChatConversationListApi(UniversalChatResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
@marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields)
def get(self, universal_app):
app_model = universal_app
@ -73,7 +59,7 @@ class UniversalChatConversationApi(UniversalChatResource):
class UniversalChatConversationRenameApi(UniversalChatResource):
@marshal_with(conversation_fields)
@marshal_with(conversation_with_model_config_fields)
def post(self, universal_app, c_id):
app_model = universal_app
conversation_id = str(c_id)

View File

@ -246,7 +246,8 @@ class ModelProviderModelParameterRuleApi(Resource):
'enabled': v.enabled,
'min': v.min,
'max': v.max,
'default': v.default
'default': v.default,
'precision': v.precision
}
for k, v in vars(parameter_rules).items()
}
@ -285,6 +286,25 @@ class ModelProviderFreeQuotaSubmitApi(Resource):
return result
class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_name: str):
parser = reqparse.RequestParser()
parser.add_argument('token', type=str, required=False, nullable=True, location='args')
args = parser.parse_args()
provider_service = ProviderService()
result = provider_service.free_quota_qualification_verify(
tenant_id=current_user.current_tenant_id,
provider_name=provider_name,
token=args['token']
)
return result
api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate')
api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>')
@ -300,3 +320,5 @@ api.add_resource(ModelProviderPaymentCheckoutUrlApi,
'/workspaces/current/model-providers/<string:provider_name>/checkout-url')
api.add_resource(ModelProviderFreeQuotaSubmitApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-submit')
api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi,
'/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify')
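The new qualification-verify endpoint takes an optional token query parameter; a minimal sketch of calling it (console prefix and session auth are assumptions):

import requests

resp = requests.get(
    'http://localhost:5001/console/api/workspaces/current/model-providers/'
    '<provider_name>/free-quota-qualification-verify',
    params={'token': 'xxx'},          # optional, per the parser above
    cookies={'session': '<cookie>'},  # session auth is an assumption
)
print(resp.json())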

View File

@ -9,4 +9,4 @@ api = ExternalApi(bp)
from .app import completion, app, conversation, message, audio
from .dataset import document
from .dataset import document, segment, dataset

View File

@ -8,25 +8,11 @@ from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import TimestampField, uuid_value
import services
from services.conversation_service import ConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class ConversationApi(AppApiResource):
@ -50,7 +36,7 @@ class ConversationApi(AppApiResource):
raise NotFound("Last Conversation Not Exists.")
class ConversationDetailApi(AppApiResource):
@marshal_with(conversation_fields)
@marshal_with(simple_conversation_fields)
def delete(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()
@ -70,7 +56,7 @@ class ConversationDetailApi(AppApiResource):
class ConversationRenameApi(AppApiResource):
@marshal_with(conversation_fields)
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()

View File

@ -0,0 +1,84 @@
from flask import request
from flask_restful import reqparse, marshal
import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from core.login.login import current_user
from core.model_providers.models.entity.model_params import ModelType
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from models.account import Account, TenantAccountJoin
from models.dataset import Dataset
from services.dataset_service import DatasetService
from services.provider_service import ProviderService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError('Name must be between 1 and 40 characters.')
return name
class DatasetApi(DatasetApiResource):
"""Resource for get datasets."""
def get(self, tenant_id):
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
provider = request.args.get('provider', default="vendor")
datasets, total = DatasetService.get_datasets(page, limit, provider,
tenant_id, current_user)
# check embedding setting
provider_service = ProviderService()
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
ModelType.EMBEDDINGS.value)
model_names = []
for valid_model in valid_model_list:
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item['indexing_technique'] == 'high_quality':
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item['embedding_available'] = True
else:
item['embedding_available'] = False
else:
item['embedding_available'] = True
response = {
'data': data,
'has_more': len(datasets) == limit,
'limit': limit,
'total': total,
'page': page
}
return response, 200
"""Resource for datasets."""
def post(self, tenant_id):
parser = reqparse.RequestParser()
parser.add_argument('name', nullable=False, required=True,
help='name is required. Name must be between 1 and 40 characters.',
type=_validate_name)
parser.add_argument('indexing_technique', type=str, location='json',
choices=('high_quality', 'economy'),
help='Invalid indexing technique.')
args = parser.parse_args()
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
name=args['name'],
indexing_technique=args['indexing_technique'],
account=current_user
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 200
api.add_resource(DatasetApi, '/datasets')
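With a dataset API key created from the console, the Service API endpoint above can be exercised directly. A minimal sketch, assuming the /v1 base reported by /datasets/api-base-info and Bearer-token auth for dataset keys:

import requests

API_BASE = 'https://api.example.com/v1'                 # from /datasets/api-base-info
HEADERS = {'Authorization': 'Bearer <dataset_api_key>'}

# Create an empty dataset; name is required and must be 1-40 characters.
dataset = requests.post(f'{API_BASE}/datasets', headers=HEADERS,
                        json={'name': 'my-dataset',
                              'indexing_technique': 'economy'}).json()

# List datasets; each item carries an embedding_available flag.
page = requests.get(f'{API_BASE}/datasets', headers=HEADERS,
                    params={'page': 1, 'limit': 20}).json()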

View File

@ -1,114 +1,291 @@
import datetime
import json
import uuid
from flask import current_app
from flask_restful import reqparse
from flask import current_app, request
from flask_restful import reqparse, marshal
from sqlalchemy import desc
from werkzeug.exceptions import NotFound
import services.dataset_service
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
DatasetNotInitedError
NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource
from core.login.login import current_user
from core.model_providers.error import ProviderTokenNotInitError
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.document_fields import document_fields, document_status_fields
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DocumentService
from services.file_service import FileService
class DocumentListApi(DatasetApiResource):
class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents."""
def post(self, dataset):
"""Create document."""
def post(self, tenant_id, dataset_id):
"""Create document by text."""
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, nullable=False, location='json')
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('doc_type', type=str, location='json')
parser.add_argument('doc_metadata', type=dict, location='json')
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
parser.add_argument('original_document_id', type=str, required=False, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset.indexing_technique:
raise DatasetNotInitedError("Dataset indexing technique must be set.")
if not dataset:
raise ValueError('Dataset does not exist.')
doc_type = args.get('doc_type')
doc_metadata = args.get('doc_metadata')
if not dataset.indexing_technique and not args['indexing_technique']:
raise ValueError('indexing_technique is required.')
if doc_type and doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
raise ValueError('Invalid doc_type.')
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + dataset.tenant_id + '/' + file_uuid + '.txt'
# save file to storage
storage.save(file_key, args.get('text'))
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=dataset.tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=args.get('name') + '.txt',
size=len(args.get('text')),
extension='txt',
mime_type='text/plain',
created_by=dataset.created_by,
created_at=datetime.datetime.utcnow(),
used=True,
used_by=dataset.created_by,
used_at=datetime.datetime.utcnow()
)
db.session.add(upload_file)
db.session.commit()
document_data = {
'data_source': {
'type': 'upload_file',
'info': [
{
'upload_file_id': upload_file.id
}
]
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
data_source = {
'type': 'upload_file',
'info_list': {
'data_source_type': 'upload_file',
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=document_data,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule,
document_data=args,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
if doc_type and doc_metadata:
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
document.doc_metadata = {}
for key, value_type in metadata_schema.items():
value = doc_metadata.get(key)
if value is not None and isinstance(value, value_type):
document.doc_metadata[key] = value
document.doc_type = doc_type
document.updated_at = datetime.datetime.utcnow()
db.session.commit()
return {'id': document.id}
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentApi(DatasetApiResource):
def delete(self, dataset, document_id):
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
def post(self, tenant_id, dataset_id, document_id):
"""Update document by text."""
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, nullable=True, location='json')
parser.add_argument('text', type=str, required=False, nullable=True, location='json')
parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json')
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
location='json')
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset does not exist.')
if args['text']:
upload_file = FileService.upload_text(args.get('text'), args.get('name'))
data_source = {
'type': 'upload_file',
'info_list': {
'data_source_type': 'upload_file',
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
args['original_document_id'] = str(document_id)
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentAddByFileApi(DatasetApiResource):
"""Resource for documents."""
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
args = {}
if 'data' in request.form:
args = json.loads(request.form['data'])
if 'doc_form' not in args:
args['doc_form'] = 'text_model'
if 'doc_language' not in args:
args['doc_language'] = 'English'
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset does not exist.')
if not dataset.indexing_technique and not args['indexing_technique']:
raise ValueError('indexing_technique is required.')
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
# save file info
file = request.files['file']
upload_file = FileService.upload_file(file)
data_source = {
'type': 'upload_file',
'info_list': {
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents."""
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
args = {}
if 'data' in request.form:
args = json.loads(request.form['data'])
if 'doc_form' not in args:
args['doc_form'] = 'text_model'
if 'doc_language' not in args:
args['doc_language'] = 'English'
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset does not exist.')
if 'file' in request.files:
# save file info
file = request.files['file']
if len(request.files) > 1:
raise TooManyFilesError()
upload_file = FileService.upload_file(file)
data_source = {
'type': 'upload_file',
'info_list': {
'file_info_list': {
'file_ids': [upload_file.id]
}
}
}
args['data_source'] = data_source
# validate args
args['original_document_id'] = str(document_id)
DocumentService.document_create_args_validate(args)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None,
created_from='api'
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {
'document': marshal(document, document_fields),
'batch': batch
}
return documents_and_batch_fields, 200
class DocumentDeleteApi(DatasetApiResource):
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise ValueError('Dataset does not exist.')
document = DocumentService.get_document(dataset.id, document_id)
@ -126,8 +303,85 @@ class DocumentApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError('Cannot delete document during indexing.')
return {'result': 'success'}, 204
return {'result': 'success'}, 200
api.add_resource(DocumentListApi, '/documents')
api.add_resource(DocumentApi, '/documents/<uuid:document_id>')
class DocumentListApi(DatasetApiResource):
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
page = request.args.get('page', default=1, type=int)
limit = request.args.get('limit', default=20, type=int)
search = request.args.get('keyword', default=None, type=str)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
query = Document.query.filter_by(
dataset_id=str(dataset_id), tenant_id=tenant_id)
if search:
search = f'%{search}%'
query = query.filter(Document.name.like(search))
query = query.order_by(desc(Document.created_at))
paginated_documents = query.paginate(
page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
response = {
'data': marshal(documents, document_fields),
'has_more': len(documents) == limit,
'limit': limit,
'total': paginated_documents.total,
'page': page
}
return response
class DocumentIndexingStatusApi(DatasetApiResource):
def get(self, tenant_id, dataset_id, batch):
dataset_id = str(dataset_id)
batch = str(batch)
tenant_id = str(tenant_id)
# get dataset
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# get documents
documents = DocumentService.get_batch_documents(dataset_id, batch)
if not documents:
raise NotFound('Documents not found.')
documents_status = []
for document in documents:
completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
DocumentSegment.status != 're_segment').count()
document.completed_segments = completed_segments
document.total_segments = total_segments
if document.is_paused:
document.indexing_status = 'paused'
documents_status.append(marshal(document, document_status_fields))
data = {
'data': documents_status
}
return data
api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text')
api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file')
api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text')
api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file')
api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>')
api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents')
api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status')
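A sketch of driving the document endpoints registered above, continuing with the same assumed API_BASE, HEADERS, and a dataset_id from the previous example:

import requests

# Create a document from raw text; indexing_technique is needed when the
# dataset has none set yet.
doc = requests.post(
    f'{API_BASE}/datasets/{dataset_id}/document/create_by_text',
    headers=HEADERS,
    json={'name': 'notes', 'text': 'hello world',
          'indexing_technique': 'economy'},
).json()

# Create a document from a file; extra options travel as a JSON string in the
# 'data' form field next to the file itself.
with open('guide.txt', 'rb') as fp:
    requests.post(
        f'{API_BASE}/datasets/{dataset_id}/document/create_by_file',
        headers=HEADERS,
        data={'data': '{"indexing_technique": "economy"}'},
        files={'file': fp},
    )

# Poll indexing progress for the returned batch.
status = requests.get(
    f"{API_BASE}/datasets/{dataset_id}/documents/{doc['batch']}/indexing-status",
    headers=HEADERS,
).json()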

View File

@ -1,20 +1,73 @@
# -*- coding:utf-8 -*-
from libs.exception import BaseHTTPException
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = 'file_too_large'
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415
class HighQualityDatasetOnlyError(BaseHTTPException):
error_code = 'high_quality_dataset_only'
description = "Current operation only supports 'high-quality' datasets."
code = 400
class DatasetNotInitializedError(BaseHTTPException):
error_code = 'dataset_not_initialized'
description = "The dataset is still being initialized or indexing. Please wait a moment."
code = 400
class ArchivedDocumentImmutableError(BaseHTTPException):
error_code = 'archived_document_immutable'
description = "Cannot operate when document was archived."
description = "The archived document is not editable."
code = 403
class DatasetNameDuplicateError(BaseHTTPException):
error_code = 'dataset_name_duplicate'
description = "The dataset name already exists. Please modify your dataset name."
code = 409
class InvalidActionError(BaseHTTPException):
error_code = 'invalid_action'
description = "Invalid action."
code = 400
class DocumentAlreadyFinishedError(BaseHTTPException):
error_code = 'document_already_finished'
description = "The document has been processed. Please refresh the page or go to the document details."
code = 400
class DocumentIndexingError(BaseHTTPException):
error_code = 'document_indexing'
description = "Cannot operate document during indexing."
code = 403
description = "The document is being processed and cannot be edited."
code = 400
class DatasetNotInitedError(BaseHTTPException):
error_code = 'dataset_not_inited'
description = "The dataset is still being initialized or indexing. Please wait a moment."
code = 403
class InvalidMetadataError(BaseHTTPException):
error_code = 'invalid_metadata'
description = "The metadata content is incorrect. Please check and verify."
code = 400
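Assuming BaseHTTPException (not shown in this diff) serializes error_code, description, and code into the HTTP error response, handlers only need to raise these classes; a sketch:

def create_dataset_or_409(name: str):
    if name_already_taken(name):       # hypothetical helper
        # Client receives HTTP 409 with error_code 'dataset_name_duplicate'.
        raise DatasetNameDuplicateError()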

View File

@ -0,0 +1,59 @@
from flask_login import current_user
from flask_restful import reqparse, marshal
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset
from services.dataset_service import DocumentService, SegmentService
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
def post(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=list, required=True, nullable=False, location='json')
args = parser.parse_args()
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
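A sketch of batch segment creation against the route above; content, answer, and keywords mirror segment_fields, though SegmentService's exact validation rules are defined elsewhere and not shown in this diff:

import requests

requests.post(
    f'{API_BASE}/datasets/{dataset_id}/documents/{document_id}/segments',
    headers=HEADERS,
    json={'segments': [
        {'content': 'A chunk of source text',
         'answer': 'Optional answer for Q&A-form documents',
         'keywords': ['chunk', 'text']},
    ]},
)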

View File

@ -2,11 +2,14 @@
from datetime import datetime
from functools import wraps
from flask import request
from flask import request, current_app
from flask_login import user_logged_in
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from core.login.login import _get_user
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, Account
from models.dataset import Dataset
from models.model import ApiToken, App
@ -43,12 +46,24 @@ def validate_dataset_token(view=None):
@wraps(view)
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token('dataset')
dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first()
if not dataset:
raise NotFound()
return view(dataset, *args, **kwargs)
tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
.filter(Tenant.id == api_token.tenant_id) \
.filter(TenantAccountJoin.tenant_id == Tenant.id) \
.filter(TenantAccountJoin.role == 'owner') \
.one_or_none()
if tenant_account_join:
tenant, ta = tenant_account_join
account = Account.query.filter_by(id=ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account)
user_logged_in.send(current_app._get_current_object(), user=_get_user())
else:
raise Unauthorized("Tenant owner account is not exist.")
else:
raise Unauthorized("Tenant is not exist.")
return view(api_token.tenant_id, *args, **kwargs)
return decorated
if view:
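A sketch of how a Service API resource would pick this up, assuming DatasetApiResource wires the decorator in through flask_restful's method_decorators (its class body is not part of this diff):

from flask_restful import Resource

class DatasetApiResource(Resource):
    method_decorators = [validate_dataset_token]

class ExampleApi(DatasetApiResource):
    def get(self, tenant_id):
        # tenant_id is injected by validate_dataset_token once the token,
        # tenant, and owner account have been resolved and logged in.
        return {'tenant_id': str(tenant_id)}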

View File

@ -6,26 +6,12 @@ from werkzeug.exceptions import NotFound
from controllers.web import api
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import TimestampField, uuid_value
from services.conversation_service import ConversationService
from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError
from services.web_conversation_service import WebConversationService
conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_fields))
}
class ConversationListApi(WebApiResource):
@ -73,7 +59,7 @@ class ConversationApi(WebApiResource):
class ConversationRenameApi(WebApiResource):
@marshal_with(conversation_fields)
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()

View File

@ -16,6 +16,8 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from langchain.agents import AgentExecutor as LCAgentExecutor
from core.helper import moderation
from core.model_providers.error import LLMError
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@ -116,6 +118,18 @@ class AgentExecutor:
return self.agent.should_use_agent(query)
def run(self, query: str) -> AgentExecuteResult:
moderation_result = moderation.check_moderation(
self.configuration.model_instance.model_provider,
query
)
if not moderation_result:
return AgentExecuteResult(
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
strategy=self.configuration.strategy,
configuration=self.configuration
)
agent_executor = LCAgentExecutor.from_agent_and_tools(
agent=self.agent,
tools=self.configuration.tools,
@ -128,7 +142,9 @@ class AgentExecutor:
try:
output = agent_executor.run(query)
except Exception:
except LLMError as ex:
raise ex
except Exception as ex:
logging.exception("agent_executor run failed")
output = None

View File

@ -6,7 +6,7 @@ from typing import Any, Dict, List, Union, Optional
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
from core.callback_handler.entity.agent_loop import AgentLoop
from core.conversation_message_task import ConversationMessageTask
@ -18,9 +18,9 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
"""Initialize callback handler."""
self.model_instant = model_instant
self.model_instance = model_instance
self.conversation_message_task = conversation_message_task
self._agent_loops = []
self._current_loop = None
@ -46,6 +46,21 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Whether to ignore chain callbacks."""
return True
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
if not self._current_loop:
# Agent starts with an LLM query
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt="\n".join([message.content for message in messages[0]]),
status='llm_started',
started_at=time.perf_counter()
)
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
@ -70,7 +85,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
else:
self._current_loop.prompt_tokens = self.model_instant.get_num_tokens(
self._current_loop.prompt_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.prompt)]
)
completion_generation = response.generations[0][0]
@ -87,7 +102,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
if response.llm_output:
self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self._current_loop.completion_tokens = self.model_instant.get_num_tokens(
self._current_loop.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self._current_loop.completion)]
)
@ -162,7 +177,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_instant, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)
@ -193,7 +208,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
)
self.conversation_message_task.on_agent_end(
self._message_agent_thought, self.model_instant, self._current_loop
self._message_agent_thought, self.model_instance, self._current_loop
)
self._agent_loops.append(self._current_loop)

View File

@ -6,4 +6,3 @@ class LLMMessage(BaseModel):
prompt_tokens: int = 0
completion: str = ''
completion_tokens: int = 0
latency: float = 0.0

View File

@ -1,5 +1,4 @@
import logging
import time
from typing import Any, Dict, List, Union
from langchain.callbacks.base import BaseCallbackHandler
@ -32,7 +31,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
messages: List[List[BaseMessage]],
**kwargs: Any
) -> Any:
self.start_at = time.perf_counter()
real_prompts = []
for message in messages[0]:
if message.type == 'human':
@ -53,8 +51,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
self.start_at = time.perf_counter()
self.llm_message.prompt = [{
"role": 'user',
"text": prompts[0]
@ -63,14 +59,22 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
self.llm_message.completion = response.generations[0][0].text
self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
if 'completion_tokens' in response.llm_output['token_usage']:
self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
else:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)])
self.conversation_message_task.save_message(self.llm_message)
@ -89,8 +93,6 @@ class LLMCallbackHandler(BaseCallbackHandler):
"""Do nothing."""
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
end_at = time.perf_counter()
self.llm_message.latency = end_at - self.start_at
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)

View File

@ -1,15 +1,33 @@
import enum
import logging
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import BaseModel
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation import openai_moderation
class SensitiveWordAvoidanceRule(BaseModel):
class Type(enum.Enum):
MODERATION = "moderation"
KEYWORDS = "keywords"
type: Type
canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
extra_params: dict = {}
class SensitiveWordAvoidanceChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
sensitive_words: List[str] = []
canned_response: str = None
model_instance: BaseLLM
sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
@property
def _chain_type(self) -> str:
@ -31,11 +49,24 @@ class SensitiveWordAvoidanceChain(Chain):
"""
return [self.output_key]
def _check_sensitive_word(self, text: str) -> str:
for word in self.sensitive_words:
def _check_sensitive_word(self, text: str) -> bool:
for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
if word in text:
return self.canned_response
return text
return False
return True
def _check_moderation(self, text: str) -> bool:
moderation_model_instance = ModelFactory.get_moderation_model(
tenant_id=self.model_instance.model_provider.provider.tenant_id,
model_provider_name='openai',
model_name=openai_moderation.DEFAULT_MODEL
)
try:
return moderation_model_instance.run(text=text)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
def _call(
self,
@ -43,5 +74,19 @@ class SensitiveWordAvoidanceChain(Chain):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
output = self._check_sensitive_word(text)
return {self.output_key: output}
if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
result = self._check_sensitive_word(text)
else:
result = self._check_moderation(text)
if not result:
raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
return {self.output_key: text}
class SensitiveWordAvoidanceError(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message
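A minimal sketch of configuring and running the keyword-based rule defined above; model_instance stands in for any BaseLLM:

rule = SensitiveWordAvoidanceRule(
    type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
    extra_params={'sensitive_words': ['forbidden', 'banned']},
)

chain = SensitiveWordAvoidanceChain(
    model_instance=model_instance,
    sensitive_word_avoidance_rule=rule,
)

try:
    safe_text = chain.run('some user input')
except SensitiveWordAvoidanceError as ex:
    print(ex.message)  # the rule's canned_response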

View File

@ -1,24 +1,22 @@
import json
import logging
import re
from typing import Optional, List, Union, Tuple
from typing import Optional, List, Union
from langchain.schema import BaseMessage
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
@ -79,28 +77,55 @@ class Completion:
app_model_config=app_model_config
)
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
if sensitive_word_avoidance_chain:
query = sensitive_word_avoidance_chain.run(query)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback
)
# run agent executor
agent_execute_result = None
if agent_executor:
should_use_agent = agent_executor.should_use_agent(query)
if should_use_agent:
agent_execute_result = agent_executor.run(query)
# run the final llm
try:
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
final_model_instance, [chain_callback])
if sensitive_word_avoidance_chain:
try:
query = sensitive_word_avoidance_chain.run(query)
except SensitiveWordAvoidanceError as ex:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=ex.message
)
return
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
conversation_message_task=conversation_message_task,
memory=memory,
rest_tokens=rest_tokens_for_context_and_memory,
chain_callback=chain_callback,
retriever_from=retriever_from
)
query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)
# run agent executor
agent_execute_result = None
if query_for_agent and agent_executor:
should_use_agent = agent_executor.should_use_agent(query_for_agent)
if should_use_agent:
agent_execute_result = agent_executor.run(query_for_agent)
# When no extra pre-prompt is specified,
# the agent's output can be used directly as the main output without calling the LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
# run the final llm
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
@@ -109,7 +134,8 @@ class Completion:
inputs=inputs,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory
memory=memory,
fake_response=fake_response
)
except ConversationTaskStoppedException:
return
@@ -118,20 +144,21 @@ class Completion:
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
if app.mode != 'completion':
return query
return inputs.get(app_model_config.dataset_query_variable, "")
@classmethod
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
inputs: dict,
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
# When no extra pre-prompt is specified,
# the output of the agent can be used directly as the main output content without calling the LLM again
fake_response = None
if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, PlanningStrategy.REACT_ROUTER]:
fake_response = agent_execute_result.output
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
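The fake_response plumbing above means the final-LLM step can emit a precomputed answer (the agent's output, or the moderation apology) instead of invoking the model. A minimal sketch of that control flow under stated assumptions; call_model is a hypothetical stand-in for the real completion call, not code from this diff:

from typing import Optional

def run_final_llm_sketch(query: str, fake_response: Optional[str]) -> str:
    # If a fake response was prepared upstream, skip the model entirely.
    if fake_response is not None:
        return fake_response
    return call_model(query)  # hypothetical real LLM invocation

def call_model(query: str) -> str:
    # Placeholder used only to keep this sketch self-contained.
    return f"LLM answer for: {query}"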

View File

@@ -1,5 +1,5 @@
import decimal
import json
import time
from typing import Optional, Union, List
from core.callback_handler.entity.agent_loop import AgentLoop
@@ -23,6 +23,8 @@ class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
self.start_at = time.perf_counter()
self.task_id = task_id
self.app = app
@@ -61,6 +63,7 @@ class ConversationMessageTask:
)
def init(self):
override_model_configs = None
if self.is_override:
override_model_configs = self.app_model_config.to_dict()
@@ -109,7 +112,7 @@ class ConversationMessageTask:
)
db.session.add(self.conversation)
db.session.flush()
db.session.commit()
self.message = Message(
app_id=self.app_model_config.app_id,
@@ -137,7 +140,7 @@ class ConversationMessageTask:
)
db.session.add(self.message)
db.session.flush()
db.session.commit()
def append_message_text(self, text: str):
if text is not None:
@@ -165,7 +168,7 @@ class ConversationMessageTask:
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
self.message.answer_price_unit = answer_price_unit
self.message.provider_response_latency = llm_message.latency
self.message.provider_response_latency = time.perf_counter() - self.start_at
self.message.total_price = total_price
db.session.commit()
@@ -188,12 +191,13 @@ class ConversationMessageTask:
)
db.session.add(message_chain)
db.session.flush()
db.session.commit()
return message_chain
def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
message_chain.output = json.dumps(chain_result.completion)
db.session.commit()
self._pub_handler.pub_chain(message_chain)
@@ -214,24 +218,24 @@ class ConversationMessageTask:
)
db.session.add(message_agent_thought)
db.session.flush()
db.session.commit()
self._pub_handler.pub_agent_thought(message_agent_thought)
return message_agent_thought
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instant.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instant.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instant.get_price_unit(MessageType.ASSISTANT)
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instant.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instant.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
message_agent_thought.observation = agent_loop.tool_output
@@ -245,8 +249,8 @@ class ConversationMessageTask:
message_agent_thought.latency = agent_loop.latency
message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
message_agent_thought.total_price = loop_total_price
message_agent_thought.currency = agent_model_instant.get_currency()
db.session.flush()
message_agent_thought.currency = agent_model_instance.get_currency()
db.session.commit()
def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
dataset_query = DatasetQuery(
@@ -259,6 +263,7 @@ class ConversationMessageTask:
)
db.session.add(dataset_query)
db.session.commit()
def on_dataset_query_finish(self, resource: List):
if resource and len(resource) > 0:
@@ -282,7 +287,7 @@ class ConversationMessageTask:
created_by=self.user.id
)
db.session.add(dataset_retriever_resource)
db.session.flush()
db.session.commit()
self.retriever_resource = resource
def message_end(self):

View File

@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
SEARCH_URL = "https://api.notion.com/v1/search"
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']

View File

@@ -0,0 +1,34 @@
import logging
import openai
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.providers.hosted import hosted_config, hosted_model_providers
from models.provider import ProviderType
def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
if hosted_config.moderation.enabled is True and hosted_model_providers.openai:
if model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and model_provider.provider_name in hosted_config.moderation.providers:
# 2000 characters per chunk
length = 2000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
for text_chunk in chunks:
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True
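The loop above caps each moderation request at 32 chunks of 2,000 characters. A quick self-contained sketch of the same batching arithmetic, with the sizes copied from the code above:

def batch_chunks(text: str, chunk_size: int = 2000, batch_size: int = 32) -> list:
    # Split into 2000-character chunks, then group up to 32 chunks per request.
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]

batches = batch_chunks("a" * 70000)  # 35 chunks in total
assert [len(b) for b in batches] == [32, 3]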

View File

@@ -16,6 +16,10 @@ class BaseIndex(ABC):
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
raise NotImplementedError
@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
raise NotImplementedError
@@ -28,6 +32,10 @@ class BaseIndex(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_document_id(self, document_id: str):
raise NotImplementedError

View File

@@ -46,6 +46,32 @@ class KeywordTableIndex(BaseIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = {}
for text in texts:
keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))
dataset_keyword_table = DatasetKeywordTable(
dataset_id=self.dataset.id,
keyword_table=json.dumps({
'__type__': 'keyword_table',
'__data__': {
"index_id": self.dataset.id,
"summary": None,
"table": {}
}
}, cls=SetEncoder)
)
db.session.add(dataset_keyword_table)
db.session.commit()
self._save_dataset_keyword_table(keyword_table)
return self
def add_texts(self, texts: list[Document], **kwargs):
keyword_table_handler = JiebaKeywordTableHandler()
@@ -120,6 +146,12 @@ class KeywordTableIndex(BaseIndex):
db.session.delete(dataset_keyword_table)
db.session.commit()
def delete_by_group_id(self, group_id: str) -> None:
dataset_keyword_table = self.dataset.dataset_keyword_table
if dataset_keyword_table:
db.session.delete(dataset_keyword_table)
db.session.commit()
def _save_dataset_keyword_table(self, keyword_table):
keyword_table_dict = {
'__type__': 'keyword_table',
@@ -214,11 +246,28 @@ class KeywordTableIndex(BaseIndex):
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
def multi_create_segment_keywords(self, pre_segment_data_list: list):
keyword_table_handler = JiebaKeywordTableHandler()
keyword_table = self._get_dataset_keyword_table()
for pre_segment_data in pre_segment_data_list:
segment = pre_segment_data['segment']
if pre_segment_data['keywords']:
segment.keywords = pre_segment_data['keywords']
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id,
pre_segment_data['keywords'])
else:
keywords = keyword_table_handler.extract_keywords(segment.content,
self._config.max_keywords_per_chunk)
segment.keywords = list(keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
self._save_dataset_keyword_table(keyword_table)
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
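For reference, the keyword table serialized in create_with_collection_name above uses a small JSON envelope; a sketch of its shape with example values (the ids and keyword are hypothetical):

example_keyword_table = {
    '__type__': 'keyword_table',
    '__data__': {
        'index_id': 'dataset-id-123',        # the dataset id
        'summary': None,
        'table': {'keyword': ['node-id-1']}  # keyword -> index node ids
    }
}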
class KeywordTableRetriever(BaseRetriever, BaseModel):
index: KeywordTableIndex
search_kwargs: dict = Field(default_factory=dict)

View File

@@ -10,7 +10,7 @@ from weaviate import UnexpectedStatusCodeException
from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument
@@ -110,6 +110,14 @@ class BaseVectorIndex(BaseIndex):
for node_id in ids:
vector_store.del_text(node_id)
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
if self.dataset.collection_binding_id:
vector_store.delete_by_group_id(group_id)
else:
vector_store.delete()
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -243,3 +251,53 @@ class BaseVectorIndex(BaseIndex):
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"restore dataset in_one,_dataset {dataset.id}")
dataset_documents = db.session.query(DatasetDocument).filter(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == 'completed',
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).all()
documents = []
for dataset_document in dataset_documents:
segments = db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == 'completed',
DocumentSegment.enabled == True
).all()
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
}
)
documents.append(document)
if documents:
try:
self.add_texts(documents)
except Exception as e:
raise e
logging.info(f"Dataset {dataset.id} recreate successfully.")
def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
logging.info(f"delete original collection: {dataset.id}")
self.delete()
dataset.collection_binding_id = dataset_collection_binding.id
db.session.add(dataset)
db.session.commit()
logging.info(f"Dataset {dataset.id} recreate successfully.")

View File

@@ -69,6 +69,19 @@ class MilvusVectorIndex(BaseVectorIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=collection_name,
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:

View File

@@ -28,6 +28,7 @@ from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@@ -84,6 +85,7 @@ class Qdrant(VectorStore):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR_NAME = None
def __init__(
@@ -93,9 +95,12 @@
embeddings: Optional[Embeddings] = None,
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
distance_strategy: str = "COSINE",
vector_name: Optional[str] = VECTOR_NAME,
embedding_function: Optional[Callable] = None, # deprecated
is_new_collection: bool = False
):
"""Initialize with necessary components."""
try:
@@ -129,7 +134,10 @@
self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
self.group_payload_key = group_payload_key or self.GROUP_KEY
self.vector_name = vector_name or self.VECTOR_NAME
self.group_id = group_id
self.is_new_collection = is_new_collection
if embedding_function is not None:
warnings.warn(
@@ -170,6 +178,8 @@
batch_size:
How many vectors upload per-request.
Default: 64
group_id:
collection group id
Returns:
List of ids from adding the texts into the vectorstore.
@@ -182,7 +192,11 @@
collection_name=self.collection_name, points=points, **kwargs
)
added_ids.extend(batch_ids)
# if this is a new collection, create a payload index on group_id
if self.is_new_collection:
self.client.create_payload_index(self.collection_name, self.group_payload_key,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
return added_ids
@sync_call_fallback
@@ -970,6 +984,8 @@
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
@@ -1034,6 +1050,11 @@
metadata_payload_key:
A payload key used to store the metadata of the document.
Default: "metadata"
group_payload_key:
A payload key used to store the group id of the document.
Default: "group_id"
group_id:
collection group id
vector_name:
Name of the vector to be used internally in Qdrant.
Default: None
@@ -1107,6 +1128,8 @@
distance_func,
content_payload_key,
metadata_payload_key,
group_payload_key,
group_id,
vector_name,
shard_number,
replication_factor,
@@ -1321,6 +1344,8 @@
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
@@ -1350,6 +1375,7 @@
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
is_new_collection = False
client = qdrant_client.QdrantClient(
location=location,
url=url,
@@ -1364,70 +1390,12 @@
path=path,
**kwargs,
)
try:
# Skip any validation in case of forced collection recreate.
if force_recreate:
raise ValueError
# Get the vector configuration of the existing collection and vector, if it
# was specified. If the old configuration does not match the current one,
# an exception is being thrown.
collection_info = client.get_collection(collection_name=collection_name)
current_vector_config = collection_info.config.params.vectors
if isinstance(current_vector_config, dict) and vector_name is not None:
if vector_name not in current_vector_config:
raise QdrantException(
f"Existing Qdrant collection {collection_name} does not "
f"contain vector named {vector_name}. Did you mean one of the "
f"existing vectors: {', '.join(current_vector_config.keys())}? "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_vector_config = current_vector_config.get(
vector_name
) # type: ignore[assignment]
elif isinstance(current_vector_config, dict) and vector_name is None:
raise QdrantException(
f"Existing Qdrant collection {collection_name} uses named vectors. "
f"If you want to reuse it, please set `vector_name` to any of the "
f"existing named vectors: "
f"{', '.join(current_vector_config.keys())}." # noqa
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
elif (
not isinstance(current_vector_config, dict) and vector_name is not None
):
raise QdrantException(
f"Existing Qdrant collection {collection_name} doesn't use named "
f"vectors. If you want to reuse it, please set `vector_name` to "
f"`None`. If you want to recreate the collection, set "
f"`force_recreate` parameter to `True`."
)
# Check if the vector configuration has the same dimensionality.
if current_vector_config.size != vector_size: # type: ignore[union-attr]
raise QdrantException(
f"Existing Qdrant collection is configured for vectors with "
f"{current_vector_config.size} " # type: ignore[union-attr]
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_distance_func = (
current_vector_config.distance.name.upper() # type: ignore[union-attr]
)
if current_distance_func != distance_func:
raise QdrantException(
f"Existing Qdrant collection is configured for "
f"{current_vector_config.distance} " # type: ignore[union-attr]
f"similarity. Please set `distance_func` parameter to "
f"`{distance_func}` if you want to reuse it. If you want to "
f"recreate the collection, set `force_recreate` parameter to "
f"`True`."
)
except (UnexpectedResponse, RpcError, ValueError):
all_collection_name = []
collections_response = client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[distance_func],
@@ -1454,6 +1422,68 @@
init_from=init_from,
timeout=timeout, # type: ignore[arg-type]
)
is_new_collection = True
if force_recreate:
raise ValueError
# Get the vector configuration of the existing collection and vector, if it
# was specified. If the old configuration does not match the current one,
# an exception is being thrown.
collection_info = client.get_collection(collection_name=collection_name)
current_vector_config = collection_info.config.params.vectors
if isinstance(current_vector_config, dict) and vector_name is not None:
if vector_name not in current_vector_config:
raise QdrantException(
f"Existing Qdrant collection {collection_name} does not "
f"contain vector named {vector_name}. Did you mean one of the "
f"existing vectors: {', '.join(current_vector_config.keys())}? "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_vector_config = current_vector_config.get(
vector_name
) # type: ignore[assignment]
elif isinstance(current_vector_config, dict) and vector_name is None:
raise QdrantException(
f"Existing Qdrant collection {collection_name} uses named vectors. "
f"If you want to reuse it, please set `vector_name` to any of the "
f"existing named vectors: "
f"{', '.join(current_vector_config.keys())}." # noqa
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
elif (
not isinstance(current_vector_config, dict) and vector_name is not None
):
raise QdrantException(
f"Existing Qdrant collection {collection_name} doesn't use named "
f"vectors. If you want to reuse it, please set `vector_name` to "
f"`None`. If you want to recreate the collection, set "
f"`force_recreate` parameter to `True`."
)
# Check if the vector configuration has the same dimensionality.
if current_vector_config.size != vector_size: # type: ignore[union-attr]
raise QdrantException(
f"Existing Qdrant collection is configured for vectors with "
f"{current_vector_config.size} " # type: ignore[union-attr]
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_distance_func = (
current_vector_config.distance.name.upper() # type: ignore[union-attr]
)
if current_distance_func != distance_func:
raise QdrantException(
f"Existing Qdrant collection is configured for "
f"{current_vector_config.distance} " # type: ignore[union-attr]
f"similarity. Please set `distance_func` parameter to "
f"`{distance_func}` if you want to reuse it. If you want to "
f"recreate the collection, set `force_recreate` parameter to "
f"`True`."
)
qdrant = cls(
client=client,
collection_name=collection_name,
@@ -1462,6 +1492,9 @@
metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func,
vector_name=vector_name,
group_id=group_id,
group_payload_key=group_payload_key,
is_new_collection=is_new_collection
)
return qdrant
@@ -1516,6 +1549,8 @@
metadatas: Optional[List[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
) -> List[dict]:
payloads = []
for i, text in enumerate(texts):
@@ -1529,6 +1564,7 @@
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
@@ -1578,7 +1614,7 @@
else:
out.append(
rest.FieldCondition(
key=f"{self.metadata_payload_key}.{key}",
key=key,
match=rest.MatchValue(value=value),
)
)
@@ -1654,6 +1690,7 @@
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
@@ -1684,6 +1721,8 @@
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
self.group_id,
self.group_payload_key
),
)
]
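The group_id written into each payload above is what the create_payload_index call in add_texts indexes for filtering. A hedged sketch of the equivalent direct qdrant-client call; the endpoint and collection name are example values, not taken from this diff:

from qdrant_client import QdrantClient
from qdrant_client.http.models import PayloadSchemaType

client = QdrantClient(url="http://localhost:6333")  # example endpoint
client.create_payload_index(
    collection_name="Vector_index_example_Node",    # example collection
    field_name="group_id",
    field_schema=PayloadSchemaType.KEYWORD,
)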

View File

@@ -6,18 +6,20 @@ from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff
from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from models.dataset import Dataset
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
@@ -43,16 +45,21 @@ class QdrantVectorIndex(BaseVectorIndex):
return 'qdrant'
def get_index_name(self, dataset: Dataset) -> str:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
class_prefix += '_Node'
if dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
return dataset_collection_binding.collection_name
else:
raise ValueError('Dataset collection binding does not exist!')
else:
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
return class_prefix
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
@@ -68,6 +75,27 @@ class QdrantVectorIndex(BaseVectorIndex):
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = QdrantVectorStore.from_documents(
texts,
self._embeddings,
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)
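The HNSW settings passed above match Qdrant's multitenancy guidance: m=0 disables the global HNSW graph and payload_m=16 builds per-group graphs over the indexed group_id payload, so each dataset is searched in isolation inside the shared collection. The same configuration in isolation, values copied from the diff:

from qdrant_client.http.models import HnswConfigDiff

multitenant_hnsw = HnswConfigDiff(
    m=0,                        # no global graph
    payload_m=16,               # per-group links keyed by the payload index
    ef_construct=100,
    full_scan_threshold=10000,
    max_indexing_threads=0,
    on_disk=False,
)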
@@ -78,8 +106,6 @@ class QdrantVectorIndex(BaseVectorIndex):
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
client = qdrant_client.QdrantClient(
**self._client_config.to_qdrant_params()
)
@@ -88,16 +114,15 @@ class QdrantVectorIndex(BaseVectorIndex):
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='page_content'
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id'
)
def _get_vector_store_class(self) -> type:
return QdrantVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -114,9 +139,6 @@ class QdrantVectorIndex(BaseVectorIndex):
))
def delete_by_ids(self, ids: list[str]) -> None:
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
@@ -132,6 +154,35 @@ class QdrantVectorIndex(BaseVectorIndex):
],
))
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=group_id),
),
],
))
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']

View File

@@ -91,6 +91,20 @@ class WeaviateVectorIndex(BaseVectorIndex):
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
)
return self
def _get_vector_store(self) -> VectorStore:
"""Only for created index."""
if self._vector_store:

View File

@@ -1,11 +1,10 @@
import os
from functools import wraps
import flask_login
from flask import current_app
from flask import g
from flask import has_request_context
from flask import request
from flask import request, session
from flask_login import user_logged_in
from flask_login.config import EXEMPT_METHODS
from werkzeug.exceptions import Unauthorized

View File

@@ -8,6 +8,7 @@ from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.models.speech2text.base import BaseSpeech2Text
from extensions.ext_database import db
from models.provider import TenantDefaultModel
@@ -180,7 +181,7 @@ class ModelFactory:
def get_moderation_model(cls,
tenant_id: str,
model_provider_name: str,
model_name: str) -> Optional[BaseProviderModel]:
model_name: str) -> Optional[BaseModeration]:
"""
get moderation model.

View File

@@ -45,6 +45,9 @@ class ModelProviderFactory:
elif provider_name == 'wenxin':
from core.model_providers.providers.wenxin_provider import WenxinProvider
return WenxinProvider
elif provider_name == 'zhipuai':
from core.model_providers.providers.zhipuai_provider import ZhipuAIProvider
return ZhipuAIProvider
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider

View File

@@ -0,0 +1,22 @@
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
from core.model_providers.models.embedding.base import BaseEmbedding
class HuggingfaceEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = HuggingfaceHubEmbeddings(
model=name,
**credentials
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Huggingface embedding: {str(ex)}")

View File

@@ -0,0 +1,22 @@
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings
class ZhipuAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = ZhipuAIEmbeddings(
model=name,
**credentials,
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"ZhipuAI embedding: {str(ex)}")

View File

@@ -49,6 +49,7 @@ class KwargRule(Generic[T], BaseModel):
max: Optional[T] = None
default: Optional[T] = None
alias: Optional[str] = None
precision: Optional[int] = None
class ModelKwargsRules(BaseModel):
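The new precision field presumably lets callers round user-supplied parameter values consistently. A hedged sketch of how a rule's bounds and precision might be applied; the helper itself is hypothetical, not part of this diff:

from typing import Optional

def apply_rule(value: float, minimum: Optional[float], maximum: Optional[float],
               precision: Optional[int]) -> float:
    # Clamp into [minimum, maximum], then round to the rule's precision.
    if minimum is not None:
        value = max(value, minimum)
    if maximum is not None:
        value = min(value, maximum)
    if precision is not None:
        value = round(value, precision)
    return value

assert apply_rule(0.123456, 0.0, 1.0, 2) == 0.12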

View File

@@ -1,6 +1,7 @@
import json
import os
import re
import time
from abc import abstractmethod
from typing import List, Optional, Any, Union, Tuple
import decimal
@@ -10,6 +11,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
@@ -19,6 +21,8 @@ from core.prompt.prompt_template import JinjaPromptTemplate
from core.third_party.langchain.llms.fake import FakeLLM
import logging
from extensions.ext_database import db
logger = logging.getLogger(__name__)
@@ -116,9 +120,20 @@ class BaseLLM(BaseProviderModel):
:param callbacks:
:return:
"""
moderation_result = moderation.check_moderation(
self.model_provider,
"\n".join([message.content for message in messages])
)
if not moderation_result:
kwargs['fake_response'] = "I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest."
if self.deduct_quota:
self.model_provider.check_quota_over_limit()
db.session.commit()
if not callbacks:
callbacks = self.callbacks
else:

View File

@@ -17,6 +17,7 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar
from models.provider import ProviderType, ProviderQuotaType
COMPLETION_MODELS = [
'gpt-3.5-turbo-instruct', # 4,096 tokens
'text-davinci-003', # 4,097 tokens
]
@@ -31,6 +32,7 @@ MODEL_MAX_TOKENS = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}

View File

@@ -18,6 +18,7 @@ class WenxinModel(BaseLLM):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin(
model=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,

View File

@@ -0,0 +1,61 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
class ZhipuAIModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ZhipuAIChatLLM(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"ZhipuAI: {str(ex)}")
@property
def support_streaming(self):
return True

View File

@@ -0,0 +1,29 @@
from abc import abstractmethod
from typing import Any
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import BaseModelProvider
class BaseModeration(BaseProviderModel):
name: str
type: ModelType = ModelType.MODERATION
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
super().__init__(model_provider, client)
self.name = name
def run(self, text: str) -> bool:
try:
return self._run(text)
except Exception as ex:
raise self.handle_exceptions(ex)
@abstractmethod
def _run(self, text: str) -> bool:
raise NotImplementedError
@abstractmethod
def handle_exceptions(self, ex: Exception) -> Exception:
raise NotImplementedError
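A minimal sketch of a subclass under this contract, where _run returns True when the text passes; the keyword blocklist is a hypothetical example, not a provider from this diff:

class KeywordModeration(BaseModeration):
    def _run(self, text: str) -> bool:
        banned = {"example-banned-word"}  # hypothetical blocklist
        return not any(word in text for word in banned)

    def handle_exceptions(self, ex: Exception) -> Exception:
        return ex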

View File

@@ -4,29 +4,39 @@ import openai
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.models.moderation.base import BaseModeration
from core.model_providers.providers.base import BaseModelProvider
DEFAULT_AUDIO_MODEL = 'whisper-1'
DEFAULT_MODEL = 'whisper-1'
class OpenAIModeration(BaseProviderModel):
type: ModelType = ModelType.MODERATION
class OpenAIModeration(BaseModeration):
def __init__(self, model_provider: BaseModelProvider, name: str):
super().__init__(model_provider, openai.Moderation)
super().__init__(model_provider, openai.Moderation, name)
def run(self, text):
def _run(self, text: str) -> bool:
credentials = self.model_provider.get_model_credentials(
model_name=DEFAULT_AUDIO_MODEL,
model_name=self.name,
model_type=self.type
)
try:
return self._client.create(input=text, api_key=credentials['openai_api_key'])
except Exception as ex:
raise self.handle_exceptions(ex)
# 2000 characters per chunk
length = 2000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
for text_chunk in chunks:
moderation_result = self._client.create(input=text_chunk,
api_key=credentials['openai_api_key'])
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):

View File

@@ -69,11 +69,11 @@ class AnthropicProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0, max=1, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256),
max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256, precision=0),
)
@classmethod

View File

@@ -164,14 +164,14 @@ class AzureOpenAIProvider(BaseModelProvider):
model_credentials = self.get_model_credentials(model_name, model_type)
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get(
model_credentials['base_model_name'],
4097
), default=16),
), default=16, precision=0),
)
@classmethod

View File

@@ -64,11 +64,11 @@ class ChatGLMProvider(BaseModelProvider):
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048),
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
)
@classmethod

View File

@@ -45,6 +45,18 @@ class HostedModelProviders(BaseModel):
hosted_model_providers = HostedModelProviders()
class HostedModerationConfig(BaseModel):
enabled: bool = False
providers: list[str] = []
class HostedConfig(BaseModel):
moderation = HostedModerationConfig()
hosted_config = HostedConfig()
def init_app(app: Flask):
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
langchain.verbose = True
@@ -78,3 +90,9 @@ def init_app(app: Flask):
paid_min_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY"),
paid_max_quantity=app.config.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY"),
)
if app.config.get("HOSTED_MODERATION_ENABLED") and app.config.get("HOSTED_MODERATION_PROVIDERS"):
hosted_config.moderation = HostedModerationConfig(
enabled=app.config.get("HOSTED_MODERATION_ENABLED"),
providers=app.config.get("HOSTED_MODERATION_PROVIDERS").split(',')
)
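So hosted moderation only activates when both config keys are set. A small sketch of the parsed result, assuming HOSTED_MODERATION_ENABLED=true and HOSTED_MODERATION_PROVIDERS=openai (example values):

example = HostedModerationConfig(enabled=True, providers="openai".split(','))
assert example.enabled and example.providers == ["openai"]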

View File

@@ -1,5 +1,6 @@
import json
from typing import Type
import requests
from huggingface_hub import HfApi
@@ -10,8 +11,12 @@ from core.model_providers.providers.base import BaseModelProvider, CredentialsVa
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
from models.provider import ProviderType
HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'
class HuggingfaceHubProvider(BaseModelProvider):
@property
@@ -33,6 +38,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = HuggingfaceHubModel
elif model_type == ModelType.EMBEDDINGS:
model_class = HuggingfaceEmbedding
else:
raise NotImplementedError
@@ -47,11 +54,11 @@ class HuggingfaceHubProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=200, precision=0),
)
@classmethod
@@ -63,7 +70,7 @@ class HuggingfaceHubProvider(BaseModelProvider):
:param model_type:
:param credentials:
"""
if model_type != ModelType.TEXT_GENERATION:
if model_type not in [ModelType.TEXT_GENERATION, ModelType.EMBEDDINGS]:
raise NotImplementedError
if 'huggingfacehub_api_type' not in credentials \
@@ -88,19 +95,15 @@ class HuggingfaceHubProvider(BaseModelProvider):
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Task Type must be provided.')
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
if credentials['task_type'] not in ("text2text-generation", "text-generation", "feature-extraction"):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
'text-generation, summarization.')
'text-generation, feature-extraction.')
try:
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
llm("ping")
if credentials['task_type'] == 'feature-extraction':
cls.check_embedding_valid(credentials, model_name)
else:
cls.check_llm_valid(credentials)
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
else:
@@ -112,13 +115,64 @@ class HuggingfaceHubProvider(BaseModelProvider):
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
VALID_TASKS = ("text2text-generation", "text-generation", "feature-extraction")
if model_info.pipeline_tag not in VALID_TASKS:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {VALID_TASKS}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
@classmethod
def check_llm_valid(cls, credentials: dict):
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)
llm("ping")
@classmethod
def check_embedding_valid(cls, credentials: dict, model_name: str):
cls.check_endpoint_url_model_repository_name(credentials, model_name)
embedding_model = HuggingfaceHubEmbeddings(
model=model_name,
**credentials
)
embedding_model.embed_query("ping")
@classmethod
def check_endpoint_url_model_repository_name(cls, credentials: dict, model_name: str):
try:
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
headers = {
'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
'Content-Type': 'application/json'
}
response = requests.get(url=url, headers=headers)
if response.status_code != 200:
raise ValueError('User Name or Organization Name is invalid.')
model_repository_name = ''
for item in response.json().get("items", []):
if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
model_repository_name = item.get("model", {}).get("repository")
break
if model_repository_name != model_name:
raise ValueError(f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')
except Exception as e:
raise ValueError(str(e))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:

View File

@@ -52,9 +52,9 @@ class LocalAIProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=0.7),
top_p=KwargRule[float](min=0, max=1, default=1),
max_tokens=KwargRule[int](min=10, max=4097, default=16),
temperature=KwargRule[float](min=0, max=2, default=0.7, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
max_tokens=KwargRule[int](min=10, max=4097, default=16, precision=0),
)
@classmethod

View File

@@ -74,11 +74,11 @@ class MinimaxProvider(BaseModelProvider):
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.9),
top_p=KwargRule[float](min=0, max=1, default=0.95),
temperature=KwargRule[float](min=0.01, max=1, default=0.9, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.95, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024, precision=0),
)
@classmethod

View File

@@ -40,6 +40,10 @@ class OpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
@@ -128,16 +132,17 @@ class OpenAIProvider(BaseModelProvider):
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=1),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16),
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=1, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16, precision=0),
)
@classmethod

View File

@@ -45,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128, precision=0),
)
@classmethod

View File

@@ -72,6 +72,7 @@ class ReplicateProvider(BaseModelProvider):
min=float(value.get('minimum')) if value.get('minimum') is not None else None,
max=float(value.get('maximum')) if value.get('maximum') is not None else None,
default=float(value.get('default')) if value.get('default') is not None else None,
precision=2
)
if key == 'temperature':
model_kwargs_rules.temperature = kwarg_rule
@@ -84,6 +85,7 @@ class ReplicateProvider(BaseModelProvider):
min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
default=int(value.get('default')) if value.get('default') is not None else 500,
precision=0
)
return model_kwargs_rules

View File

@@ -62,11 +62,11 @@ class SparkProvider(BaseModelProvider):
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.5),
temperature=KwargRule[float](min=0, max=1, default=0.5, precision=2),
top_p=KwargRule[float](enabled=False),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4096, default=2048),
max_tokens=KwargRule[int](min=10, max=4096, default=2048, precision=0),
)
@classmethod

View File

@@ -64,10 +64,10 @@ class TongyiProvider(BaseModelProvider):
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](min=0, max=1, default=0.8),
top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
)
@classmethod

View File

@@ -61,13 +61,18 @@ class WenxinProvider(BaseModelProvider):
:param model_type:
:return:
"""
model_max_tokens = {
'ernie-bot': 4800,
'ernie-bot-turbo': 11200,
}
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95),
top_p=KwargRule[float](min=0.01, max=1, default=0.8),
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
max_tokens=KwargRule[int](enabled=False, max=model_max_tokens.get(model_name)),
)
else:
return ModelKwargsRules(

View File

@@ -2,6 +2,7 @@ import json
from typing import Type
import requests
from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -52,27 +53,27 @@ class XinferenceProvider(BaseModelProvider):
credentials = self.get_model_credentials(model_name, model_type)
if credentials['model_format'] == "ggmlv3" and credentials["model_handle_type"] == "chatglm":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
elif credentials['model_format'] == "ggmlv3":
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0, precision=2),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
else:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
temperature=KwargRule[float](min=0.01, max=2, default=1, precision=2),
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=4000, default=256),
max_tokens=KwargRule[int](min=10, max=4000, default=256, precision=0),
)
@@ -97,11 +98,18 @@ class XinferenceProvider(BaseModelProvider):
'model_uid': credentials['model_uid'],
}
llm = XinferenceLLM(
**credential_kwargs
)
if model_type == ModelType.TEXT_GENERATION:
llm = XinferenceLLM(
**credential_kwargs
)
llm("ping")
llm("ping")
elif model_type == ModelType.EMBEDDINGS:
embedding = XinferenceEmbeddings(
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@@ -117,8 +125,9 @@ class XinferenceProvider(BaseModelProvider):
:param credentials:
:return:
"""
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
if model_type == ModelType.TEXT_GENERATION:
extra_credentials = cls._get_extra_credentials(credentials)
credentials.update(extra_credentials)
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])

View File

@@ -0,0 +1,176 @@
import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
from models.provider import ProviderType, ProviderQuotaType
class ZhipuAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'zhipuai'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'chatglm_pro',
'name': 'chatglm_pro',
},
{
'id': 'chatglm_std',
'name': 'chatglm_std',
},
{
'id': 'chatglm_lite',
'name': 'chatglm_lite',
},
{
'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k',
}
]
elif model_type == ModelType.EMBEDDINGS:
return [
{
'id': 'text_embedding',
'name': 'text_embedding',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = ZhipuAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = ZhipuAIEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.1, max=0.9, default=0.8, precision=1),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('ZhipuAI api_key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key']
}
llm = ZhipuAIChatLLM(
temperature=0.01,
**credential_kwargs
)
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value \
or (self.provider.provider_type == ProviderType.SYSTEM.value
and self.provider.quota_type == ProviderQuotaType.FREE.value):
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
return credentials
else:
return {}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)
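
A quick sketch of the credentials round-trip this provider implements: encrypt on save, decrypt on load, obfuscate for display. The function bodies below are placeholders; only the call shape mirrors the encrypter helpers used in this diff.

import base64

def encrypt_token(tenant_id: str, token: str) -> str:
    # the real helper uses per-tenant keys; base64 here is only a placeholder
    return base64.b64encode(f"{tenant_id}:{token}".encode()).decode()

def decrypt_token(tenant_id: str, token: str) -> str:
    return base64.b64decode(token).decode().split(":", 1)[1]

def obfuscated_token(token: str) -> str:
    # keep a short prefix and suffix, mask the middle for display
    return token[:6] + "*" * max(len(token) - 10, 0) + token[-4:]

stored = encrypt_token("tenant-1", "zhipuai-secret-key-123")
plain = decrypt_token("tenant-1", stored)
print(obfuscated_token(plain))  # zhipua************-123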

View File

@@ -6,6 +6,7 @@
"tongyi",
"spark",
"wenxin",
"zhipuai",
"chatglm",
"replicate",
"huggingface_hub",

View File

@@ -30,6 +30,12 @@
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-instruct": {
"prompt": "0.0015",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-16k": {
"prompt": "0.003",
"completion": "0.004",

View File

@@ -0,0 +1,44 @@
{
"support_provider_types": [
"system",
"custom"
],
"system_config": {
"supported_quota_types": [
"free"
],
"quota_unit": "tokens"
},
"model_flexibility": "fixed",
"price_config": {
"chatglm_pro": {
"prompt": "0.01",
"completion": "0.01",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_std": {
"prompt": "0.005",
"completion": "0.005",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_lite": {
"prompt": "0.002",
"completion": "0.002",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_lite_32k": {
"prompt": "0.0004",
"completion": "0.0004",
"unit": "0.001",
"currency": "RMB"
},
"text_embedding": {
"completion": "0",
"unit": "0.001",
"currency": "RMB"
}
}
}
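
Reading the price_config above, the field semantics are inferred from the JSON: "unit" is the token block size the prices apply to, so 0.001 means the prompt/completion prices are per 1K tokens. A small sketch of the implied cost arithmetic, using Decimal to avoid float drift:

from decimal import Decimal

price_config = {
    "chatglm_std": {"prompt": "0.005", "completion": "0.005",
                    "unit": "0.001", "currency": "RMB"},
}

def message_cost(model: str, prompt_tokens: int, completion_tokens: int) -> Decimal:
    cfg = price_config[model]
    total = (prompt_tokens * Decimal(cfg["prompt"])
             + completion_tokens * Decimal(cfg["completion"]))
    return total * Decimal(cfg["unit"])  # scale from per-token-block price

print(message_cost("chatglm_std", 1200, 300))  # 0.0075000 RMB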

View File

@@ -1,6 +1,7 @@
import math
from typing import Optional
from flask import current_app
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
@@ -12,7 +13,7 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
@@ -26,6 +27,7 @@ from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
from models.provider import ProviderType
class OrchestratorRuleParser:
@@ -63,7 +65,7 @@ class OrchestratorRuleParser:
# add agent callback to record agent thoughts
agent_callback = AgentLoopGatherCallbackHandler(
model_instant=agent_model_instance,
model_instance=agent_model_instance,
conversation_message_task=conversation_message_task
)
@@ -123,23 +125,45 @@
return chain
def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
-> Optional[SensitiveWordAvoidanceChain]:
"""
Convert app sensitive word avoidance config to chain
:param model_instance: model instance
:param callbacks: callbacks for the chain
:param kwargs:
:return:
"""
if not self.app_model_config.sensitive_word_avoidance_dict:
return None
sensitive_word_avoidance_rule = None
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
if self.app_model_config.sensitive_word_avoidance_dict:
sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_config.get("enabled", False):
if sensitive_word_avoidance_config.get('type') == 'moderation':
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
type=SensitiveWordAvoidanceRule.Type.MODERATION,
canned_response=sensitive_word_avoidance_config.get("canned_response")
if sensitive_word_avoidance_config.get("canned_response")
else 'Your content violates our usage policy. Please revise and try again.',
)
else:
sensitive_words = sensitive_word_avoidance_config.get("words", "")
if sensitive_words:
sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
canned_response=sensitive_word_avoidance_config.get("canned_response")
if sensitive_word_avoidance_config.get("canned_response")
else 'Your content violates our usage policy. Please revise and try again.',
extra_params={
'sensitive_words': sensitive_words.split(','),
}
)
if sensitive_word_avoidance_rule:
return SensitiveWordAvoidanceChain(
sensitive_words=sensitive_words.split(","),
canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
model_instance=model_instance,
sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
output_key="sensitive_word_avoidance_output",
callbacks=callbacks,
**kwargs
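
The refactor above replaces the flat keyword list with a SensitiveWordAvoidanceRule that is either moderation-based or keyword-based. A self-contained sketch of the keyword branch, assuming nothing beyond the comma-separated word list and canned response seen in this hunk:

from typing import Optional

def check_keywords(text: str, words_csv: str, canned_response: str) -> Optional[str]:
    words = [w.strip().lower() for w in words_csv.split(",") if w.strip()]
    if any(w in text.lower() for w in words):
        return canned_response  # violation: answer with the canned response
    return None                 # no violation: let the request proceed

print(check_keywords(
    "tell me about forbidden_topic",
    "forbidden_topic,banned_word",
    "Your content violates our usage policy. Please revise and try again.",
))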

View File

@@ -0,0 +1,74 @@
from typing import Any, Dict, List, Optional
import json
import numpy as np
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from huggingface_hub import InferenceClient
HOSTED_INFERENCE_API = 'hosted_inference_api'
INFERENCE_ENDPOINTS = 'inference_endpoints'
class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
client: Any
model: str
huggingface_namespace: Optional[str] = None
task_type: Optional[str] = None
huggingfacehub_api_type: Optional[str] = None
huggingfacehub_api_token: Optional[str] = None
huggingfacehub_endpoint_url: Optional[str] = None
class Config:
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
values['huggingfacehub_api_token'] = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
values['client'] = InferenceClient(token=values['huggingfacehub_api_token'])
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
model = ''
if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
model = self.model
else:
model = self.huggingfacehub_endpoint_url
output = self.client.post(
json={
"inputs": texts,
"options": {
"wait_for_model": False,
"use_cache": False
}
}, model=model)
embeddings = json.loads(output.decode())
return self.mean_pooling(embeddings)
def embed_query(self, text: str) -> List[float]:
return self.embed_documents([text])[0]
# https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
# Returned values are a list of floats, or a list of lists of floats
# (depending on whether you sent a string or a list of strings,
# and whether automatic reduction, usually mean pooling, was applied for you.
# This should be explained in the model's README.)
def mean_pooling(self, embeddings: List) -> List[float]:
# If the model already applied automatic reduction, no pooling is needed.
# Case one: List[List[float]], one vector per sentence.
if not isinstance(embeddings[0][0], list):
return embeddings
# Case two: List[List[List[float]]], per-token vectors that need mean pooling.
sentence_embeddings = [np.mean(embedding[0], axis=0).tolist() for embedding in embeddings]
return sentence_embeddings
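
A worked example of the pooling path above: when the Inference API returns per-token vectors, each sentence entry wraps a tokens-by-dimensions matrix, and averaging over axis 0 yields one vector per sentence. The dummy values below just mirror the embedding[0] indexing used in mean_pooling:

import numpy as np

embeddings = [
    [[[0.0, 2.0], [2.0, 0.0]]],  # sentence 1: two 2-dim token vectors
    [[[1.0, 1.0], [3.0, 3.0]]],  # sentence 2
]
pooled = [np.mean(e[0], axis=0).tolist() for e in embeddings]
print(pooled)  # [[1.0, 1.0], [2.0, 2.0]]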

View File

@@ -0,0 +1,64 @@
"""Wrapper around ZhipuAI embedding models."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from core.third_party.langchain.llms.zhipuai_llm import ZhipuModelAPI
class ZhipuAIEmbeddings(BaseModel, Embeddings):
"""Wrapper around ZhipuAI embedding models.
1024 dimensions.
"""
client: Any #: :meta private:
model: str
"""Model name to use."""
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
api_key: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["api_key"] = get_from_dict_or_env(
values, "api_key", "ZHIPUAI_API_KEY"
)
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to ZhipuAI's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
response = self.client.invoke(model=self.model, prompt=text)
data = response["data"]
embeddings.append(data.get('embedding'))
return [list(map(float, e)) for e in embeddings]
def embed_query(self, text: str) -> List[float]:
"""Call out to ZhipuAI's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
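
A short usage sketch for the wrapper above; the module path and the key value are assumptions for illustration, and the API key can equally come from the ZHIPUAI_API_KEY environment variable, as validate_environment shows:

import os
from core.third_party.langchain.embeddings.zhipuai_embedding import ZhipuAIEmbeddings  # path assumed

os.environ.setdefault("ZHIPUAI_API_KEY", "your-api-key")  # placeholder key
emb = ZhipuAIEmbeddings(model="text_embedding")
vectors = emb.embed_documents(["hello", "world"])  # one 1024-dim vector per text
query_vector = emb.embed_query("hello")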

View File

@@ -16,7 +16,7 @@ class HuggingFaceHubLLM(HuggingFaceHub):
environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
Only supports `text-generation`, `text2text-generation` and `summarization` for now.
Only supports `text-generation`, `text2text-generation` for now.
Example:
.. code-block:: python

View File

@@ -14,6 +14,9 @@ class EnhanceOpenAI(OpenAI):
max_retries: int = 1
"""Maximum number of retries to make when generating."""
def __new__(cls, **data: Any): # type: ignore
return super(EnhanceOpenAI, cls).__new__(cls)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""

View File

@@ -0,0 +1,315 @@
"""Wrapper around ZhipuAI APIs."""
from __future__ import annotations
import json
import logging
import posixpath
from typing import (
Any,
Dict,
List,
Optional, Iterator, Sequence,
)
import zhipuai
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import Extra, root_validator, BaseModel
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.utils import get_from_dict_or_env
from zhipuai.model_api.api import InvokeType
from zhipuai.utils import jwt_token
from zhipuai.utils.http_client import post, stream
from zhipuai.utils.sse_client import SSEClient
logger = logging.getLogger(__name__)
class ZhipuModelAPI(BaseModel):
base_url: str
api_key: str
api_timeout_seconds = 60
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def invoke(self, **kwargs):
url = self._build_api_url(kwargs, InvokeType.SYNC)
response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
if not response['success']:
raise ValueError(
f"Error Code: {response['code']}, Message: {response['msg']} "
)
return response
def sse_invoke(self, **kwargs):
url = self._build_api_url(kwargs, InvokeType.SSE)
data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
return SSEClient(data)
def _build_api_url(self, kwargs, *path):
if kwargs:
if "model" not in kwargs:
raise Exception("model param missed")
model = kwargs.pop("model")
else:
model = "-"
return posixpath.join(self.base_url, model, *path)
def _generate_token(self):
if not self.api_key:
raise Exception(
"api_key not provided, you could provide it."
)
try:
return jwt_token.generate_token(self.api_key)
except Exception:
raise ValueError(
f"Your api_key is invalid, please check it."
)
class ZhipuAIChatLLM(BaseChatModel):
"""Wrapper around ZhipuAI large language models.
To use, you should pass the api_key as a named parameter to the constructor.
Example:
.. code-block:: python
from core.third_party.langchain.llms.zhipuai import ZhipuAI
model = ZhipuAI(model="<model_name>", api_key="my-api-key")
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": "API_KEY"}
@property
def lc_serializable(self) -> bool:
return True
client: Any = None #: :meta private:
model: str = "chatglm_lite"
"""Model name to use."""
temperature: float = 0.95
"""A non-negative float that tunes the degree of randomness in generation."""
top_p: float = 0.7
"""Total probability mass of tokens to consider at each step."""
streaming: bool = False
"""Whether to stream the response or return it all at once."""
api_key: Optional[str] = None
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["api_key"] = get_from_dict_or_env(
values, "api_key", "ZHIPUAI_API_KEY"
)
if 'test' in values['base_url']:
values['model'] = 'chatglm_130b_test'
values['client'] = ZhipuModelAPI(api_key=values['api_key'], base_url=values['base_url'])
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
"model": self.model,
"temperature": self.temperature,
"top_p": self.top_p
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return self._default_params
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "zhipuai"
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
dict_messages = []
for m in messages:
message = self._convert_message_to_dict(m)
if dict_messages:
previous_message = dict_messages[-1]
if previous_message['role'] == message['role']:
dict_messages[-1]['content'] += f"\n{message['content']}"
else:
dict_messages.append(message)
else:
dict_messages.append(message)
return dict_messages
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
generation: Optional[ChatGenerationChunk] = None
llm_output: Optional[Dict] = None
for chunk in self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if chunk.generation_info is not None \
and 'token_usage' in chunk.generation_info:
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
continue
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation], llm_output=llm_output)
else:
message_dicts = self._create_message_dicts(messages)
request = self._default_params
request["prompt"] = message_dicts
request.update(kwargs)
response = self.client.invoke(**request)
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages)
request = self._default_params
request["prompt"] = message_dicts
request.update(kwargs)
for event in self.client.sse_invoke(incremental=True, **request).events():
if event.event == "add":
yield ChatGenerationChunk(message=AIMessageChunk(content=event.data))
if run_manager:
run_manager.on_llm_new_token(event.data)
elif event.event == "error" or event.event == "interrupted":
raise ValueError(
f"{event.data}"
)
elif event.event == "finish":
meta = json.loads(event.meta)
token_usage = meta['usage']
if token_usage is not None:
if 'prompt_tokens' not in token_usage:
token_usage['prompt_tokens'] = 0
if 'completion_tokens' not in token_usage:
token_usage['completion_tokens'] = token_usage['total_tokens']
yield ChatGenerationChunk(
message=AIMessageChunk(content=event.data),
generation_info=dict({'token_usage': token_usage})
)
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
data = response["data"]
generations = []
for res in data["choices"]:
message = self._convert_dict_to_message(res)
gen = ChatGeneration(
message=message
)
generations.append(gen)
token_usage = data.get("usage")
if token_usage is not None:
if 'prompt_tokens' not in token_usage:
token_usage['prompt_tokens'] = 0
if 'completion_tokens' not in token_usage:
token_usage['completion_tokens'] = token_usage['total_tokens']
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
# def get_token_ids(self, text: str) -> List[int]:
# """Return the ordered ids of the tokens in a text.
#
# Args:
# text: The string input to tokenize.
#
# Returns:
# A list of ids corresponding to the tokens in the text, in order they occur
# in the text.
# """
# from core.third_party.transformers.Token import ChatGLMTokenizer
#
# tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm2-6b")
# return tokenizer.encode(text)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input will fit in a model's context window.
Args:
messages: The message inputs to tokenize.
Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(m.content) for m in messages])
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model}

View File

@@ -33,7 +33,6 @@ class DatasetRetrieverTool(BaseTool):
return_resource: str
retriever_from: str
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
description = dataset.description
@@ -94,7 +93,10 @@ class DatasetRetrieverTool(BaseTool):
query,
search_type='similarity_score_threshold',
search_kwargs={
'k': self.k
'k': self.k,
'filter': {
'group_id': [dataset.id]
}
}
)
else:

View File

@@ -46,6 +46,11 @@ class QdrantVectorStore(Qdrant):
self.client.delete_collection(collection_name=self.collection_name)
def delete_group(self):
self._reload_if_needed()
self.client.delete_collection(collection_name=self.collection_name)
@classmethod
def _document_from_scored_point(
cls,

View File

@@ -5,4 +5,5 @@ from tasks.clean_dataset_task import clean_dataset_task
@dataset_was_deleted.connect
def handle(sender, **kwargs):
dataset = sender
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct)
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
dataset.index_struct, dataset.collection_binding_id)
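
The handler above now forwards collection_binding_id to the cleanup task. A self-contained sketch of the same blinker signal wiring, with a stand-in dataset object and a print in place of the Celery .delay() call:

from blinker import signal

dataset_was_deleted = signal("dataset-was-deleted")

@dataset_was_deleted.connect
def handle(sender, **kwargs):
    dataset = sender
    # the real handler enqueues clean_dataset_task.delay(...) here
    print("cleaning", dataset["id"], dataset["collection_binding_id"])

dataset_was_deleted.send({"id": "ds-1", "collection_binding_id": "cb-1"})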

View File

@@ -1,174 +0,0 @@
import redis
from redis.connection import SSLConnection, Connection
from flask import request
from flask_session import Session, SqlAlchemySessionInterface, RedisSessionInterface
from flask_session.sessions import total_seconds
from itsdangerous import want_bytes
from extensions.ext_database import db
sess = Session()
def init_app(app):
sqlalchemy_session_interface = CustomSqlAlchemySessionInterface(
app,
db,
app.config.get('SESSION_SQLALCHEMY_TABLE', 'sessions'),
app.config.get('SESSION_KEY_PREFIX', 'session:'),
app.config.get('SESSION_USE_SIGNER', False),
app.config.get('SESSION_PERMANENT', True)
)
session_type = app.config.get('SESSION_TYPE')
if session_type == 'sqlalchemy':
app.session_interface = sqlalchemy_session_interface
elif session_type == 'redis':
connection_class = Connection
if app.config.get('SESSION_REDIS_USE_SSL', False):
connection_class = SSLConnection
sess_redis_client = redis.Redis()
sess_redis_client.connection_pool = redis.ConnectionPool(**{
'host': app.config.get('SESSION_REDIS_HOST', 'localhost'),
'port': app.config.get('SESSION_REDIS_PORT', 6379),
'username': app.config.get('SESSION_REDIS_USERNAME', None),
'password': app.config.get('SESSION_REDIS_PASSWORD', None),
'db': app.config.get('SESSION_REDIS_DB', 2),
'encoding': 'utf-8',
'encoding_errors': 'strict',
'decode_responses': False
}, connection_class=connection_class)
app.extensions['session_redis'] = sess_redis_client
app.session_interface = CustomRedisSessionInterface(
sess_redis_client,
app.config.get('SESSION_KEY_PREFIX', 'session:'),
app.config.get('SESSION_USE_SIGNER', False),
app.config.get('SESSION_PERMANENT', True)
)
class CustomSqlAlchemySessionInterface(SqlAlchemySessionInterface):
def __init__(
self,
app,
db,
table,
key_prefix,
use_signer=False,
permanent=True,
sequence=None,
autodelete=False,
):
if db is None:
from flask_sqlalchemy import SQLAlchemy
db = SQLAlchemy(app)
self.db = db
self.key_prefix = key_prefix
self.use_signer = use_signer
self.permanent = permanent
self.autodelete = autodelete
self.sequence = sequence
self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
class Session(self.db.Model):
__tablename__ = table
if sequence:
id = self.db.Column( # noqa: A003, VNE003, A001
self.db.Integer, self.db.Sequence(sequence), primary_key=True
)
else:
id = self.db.Column( # noqa: A003, VNE003, A001
self.db.Integer, primary_key=True
)
session_id = self.db.Column(self.db.String(255), unique=True)
data = self.db.Column(self.db.LargeBinary)
expiry = self.db.Column(self.db.DateTime)
def __init__(self, session_id, data, expiry):
self.session_id = session_id
self.data = data
self.expiry = expiry
def __repr__(self):
return f"<Session data {self.data}>"
self.sql_session_model = Session
def save_session(self, *args, **kwargs):
if request.blueprint == 'service_api':
return
elif request.method == 'OPTIONS':
return
elif request.endpoint and request.endpoint == 'health':
return
return super().save_session(*args, **kwargs)
class CustomRedisSessionInterface(RedisSessionInterface):
def save_session(self, app, session, response):
if request.blueprint == 'service_api':
return
elif request.method == 'OPTIONS':
return
elif request.endpoint and request.endpoint == 'health':
return
if not self.should_set_cookie(app, session):
return
domain = self.get_cookie_domain(app)
path = self.get_cookie_path(app)
if not session:
if session.modified:
self.redis.delete(self.key_prefix + session.sid)
response.delete_cookie(
app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
)
return
# Modification case. There are upsides and downsides to
# emitting a set-cookie header each request. The behavior
# is controlled by the :meth:`should_set_cookie` method
# which performs a quick check to figure out if the cookie
# should be set or not. This is controlled by the
# SESSION_REFRESH_EACH_REQUEST config flag as well as
# the permanent flag on the session itself.
# if not self.should_set_cookie(app, session):
# return
conditional_cookie_kwargs = {}
httponly = self.get_cookie_httponly(app)
secure = self.get_cookie_secure(app)
if self.has_same_site_capability:
conditional_cookie_kwargs["samesite"] = self.get_cookie_samesite(app)
expires = self.get_expiration_time(app, session)
if session.permanent:
value = self.serializer.dumps(dict(session))
if value is not None:
self.redis.setex(
name=self.key_prefix + session.sid,
value=value,
time=total_seconds(app.permanent_session_lifetime),
)
if self.use_signer:
session_id = self._get_signer(app).sign(want_bytes(session.sid)).decode("utf-8")
else:
session_id = session.sid
response.set_cookie(
app.config["SESSION_COOKIE_NAME"],
session_id,
expires=expires,
httponly=httponly,
domain=domain,
path=path,
secure=secure,
**conditional_cookie_kwargs,
)

api/fields/__init__.py Normal file (0 additions)
View File

api/fields/app_fields.py Normal file (138 additions)
View File

@@ -0,0 +1,138 @@
from flask_restful import fields
from libs.helper import TimestampField
app_detail_kernel_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
}
related_app_list = {
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
'total': fields.Integer,
}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
'model': fields.Raw(attribute='model_dict'),
'user_input_form': fields.Raw(attribute='user_input_form_list'),
'dataset_query_variable': fields.String,
'pre_prompt': fields.String,
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
}
app_detail_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'created_at': TimestampField
}
prompt_config_fields = {
'prompt_template': fields.String,
}
model_config_partial_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
app_partial_fields = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
'created_at': TimestampField
}
app_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
}
template_fields = {
'name': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'mode': fields.String,
'model_config': fields.Nested(model_config_fields),
}
template_list_fields = {
'data': fields.List(fields.Nested(template_fields)),
}
site_fields = {
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean,
'app_base_url': fields.String,
}
app_detail_fields_with_site = {
'id': fields.String,
'name': fields.String,
'mode': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'enable_site': fields.Boolean,
'enable_api': fields.Boolean,
'api_rpm': fields.Integer,
'api_rph': fields.Integer,
'is_demo': fields.Boolean,
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
'site': fields.Nested(site_fields),
'api_base_url': fields.String,
'created_at': TimestampField
}
app_site_fields = {
'app_id': fields.String,
'access_token': fields.String(attribute='code'),
'code': fields.String,
'title': fields.String,
'icon': fields.String,
'icon_background': fields.String,
'description': fields.String,
'default_language': fields.String,
'customize_domain': fields.String,
'copyright': fields.String,
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean
}

View File

@@ -0,0 +1,182 @@
from flask_restful import fields
from libs.helper import TimestampField
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]['text'] if value else ''
account_fields = {
'id': fields.String,
'name': fields.String,
'email': fields.String
}
feedback_fields = {
'rating': fields.String,
'content': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account': fields.Nested(account_fields, allow_null=True),
}
annotation_fields = {
'content': fields.String,
'account': fields.Nested(account_fields, allow_null=True),
'created_at': TimestampField
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
'inputs': fields.Raw,
'query': fields.String,
'message': fields.Raw,
'message_tokens': fields.Integer,
'answer': fields.String,
'answer_tokens': fields.Integer,
'provider_response_latency': fields.Float,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
}
feedback_stat_fields = {
'like': fields.Integer,
'dislike': fields.Integer
}
model_config_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
'model': fields.Raw,
'user_input_form': fields.Raw,
'pre_prompt': fields.String,
'agent_mode': fields.Raw,
}
simple_configs_fields = {
'prompt_template': fields.String,
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
simple_message_detail_fields = {
'inputs': fields.Raw,
'query': fields.String,
'message': MessageTextField,
'answer': fields.String,
}
conversation_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String(),
'from_account_id': fields.String,
'read_at': TimestampField,
'created_at': TimestampField,
'annotation': fields.Nested(annotation_fields, allow_null=True),
'model_config': fields.Nested(simple_model_config_fields),
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
}
conversation_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
}
conversation_message_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'model_config': fields.Nested(model_config_fields),
'message': fields.Nested(message_detail_fields, attribute='first_message'),
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
conversation_with_summary_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_end_user_session_id': fields.String,
'from_account_id': fields.String,
'summary': fields.String(attribute='summary_or_query'),
'read_at': TimestampField,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(simple_model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
conversation_with_summary_pagination_fields = {
'page': fields.Integer,
'limit': fields.Integer(attribute='per_page'),
'total': fields.Integer,
'has_more': fields.Boolean(attribute='has_next'),
'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items')
}
conversation_detail_fields = {
'id': fields.String,
'status': fields.String,
'from_source': fields.String,
'from_end_user_id': fields.String,
'from_account_id': fields.String,
'created_at': TimestampField,
'annotated': fields.Boolean,
'model_config': fields.Nested(model_config_fields),
'message_count': fields.Integer,
'user_feedback_stats': fields.Nested(feedback_stat_fields),
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
}
simple_conversation_fields = {
'id': fields.String,
'name': fields.String,
'inputs': fields.Raw,
'status': fields.String,
'introduction': fields.String,
'created_at': TimestampField
}
conversation_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(simple_conversation_fields))
}
conversation_with_model_config_fields = {
**simple_conversation_fields,
'model_config': fields.Raw,
}
conversation_with_model_config_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_with_model_config_fields))
}

View File

@@ -0,0 +1,65 @@
from flask_restful import fields
from libs.helper import TimestampField
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'is_bound': fields.Boolean,
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields))
}
integrate_notion_info_list_fields = {
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
}
integrate_icon_fields = {
'type': fields.String,
'url': fields.String,
'emoji': fields.String
}
integrate_page_fields = {
'page_name': fields.String,
'page_id': fields.String,
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
'parent_id': fields.String,
'type': fields.String
}
integrate_workspace_fields = {
'workspace_name': fields.String,
'workspace_id': fields.String,
'workspace_icon': fields.String,
'pages': fields.List(fields.Nested(integrate_page_fields)),
'total': fields.Integer
}
integrate_fields = {
'id': fields.String,
'provider': fields.String,
'created_at': TimestampField,
'is_bound': fields.Boolean,
'disabled': fields.Boolean,
'link': fields.String,
'source_info': fields.Nested(integrate_workspace_fields)
}
integrate_list_fields = {
'data': fields.List(fields.Nested(integrate_fields)),
}

View File

@@ -0,0 +1,43 @@
from flask_restful import fields
from libs.helper import TimestampField
dataset_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
}
dataset_detail_fields = {
'id': fields.String,
'name': fields.String,
'description': fields.String,
'provider': fields.String,
'permission': fields.String,
'data_source_type': fields.String,
'indexing_technique': fields.String,
'app_count': fields.Integer,
'document_count': fields.Integer,
'word_count': fields.Integer,
'created_by': fields.String,
'created_at': TimestampField,
'updated_by': fields.String,
'updated_at': TimestampField,
'embedding_model': fields.String,
'embedding_model_provider': fields.String,
'embedding_available': fields.Boolean
}
dataset_query_detail_fields = {
"id": fields.String,
"content": fields.String,
"source": fields.String,
"source_app_id": fields.String,
"created_by_role": fields.String,
"created_by": fields.String,
"created_at": TimestampField
}

View File

@@ -0,0 +1,76 @@
from flask_restful import fields
from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField
document_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'doc_form': fields.String,
}
document_with_segments_fields = {
'id': fields.String,
'position': fields.Integer,
'data_source_type': fields.String,
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
'dataset_process_rule_id': fields.String,
'name': fields.String,
'created_from': fields.String,
'created_by': fields.String,
'created_at': TimestampField,
'tokens': fields.Integer,
'indexing_status': fields.String,
'error': fields.String,
'enabled': fields.Boolean,
'disabled_at': TimestampField,
'disabled_by': fields.String,
'archived': fields.Boolean,
'display_status': fields.String,
'word_count': fields.Integer,
'hit_count': fields.Integer,
'completed_segments': fields.Integer,
'total_segments': fields.Integer
}
dataset_and_document_fields = {
'dataset': fields.Nested(dataset_fields),
'documents': fields.List(fields.Nested(document_fields)),
'batch': fields.String
}
document_status_fields = {
'id': fields.String,
'indexing_status': fields.String,
'processing_started_at': TimestampField,
'parsing_completed_at': TimestampField,
'cleaning_completed_at': TimestampField,
'splitting_completed_at': TimestampField,
'completed_at': TimestampField,
'paused_at': TimestampField,
'error': fields.String,
'stopped_at': TimestampField,
'completed_segments': fields.Integer,
'total_segments': fields.Integer,
}
document_status_fields_list = {
'data': fields.List(fields.Nested(document_status_fields))
}

Some files were not shown because too many files have changed in this diff.